This repository contains a PyTorch implementation of Adaptive Batch Normalization (AdaBN), a technique that adapts BatchNorm statistics to the target domain so a model generalizes better under domain shift.
## Reference

Based on the paper *Revisiting Batch Normalization For Practical Domain Adaptation* (Li et al., 2016).

## Files

- `batchnorm_adapt.py`: Main script that demonstrates AdaBN usage.
- `utils.py`: Core AdaBN implementation, including the hook class and the statistics-computation functions.
## Requirements

- Python 3.7+
- PyTorch 1.10+
- NumPy

Install the required packages:

```shell
pip install torch numpy
```
## Usage

```python
from utils import compute_bn_stats, replace_bn_stats

# Set model to eval mode
model.eval()

# Compute target domain statistics
bn_stats = compute_bn_stats(model, target_dataloader)

# Apply AdaBN
replace_bn_stats(model, bn_stats)
```
Or run the demo script directly:

```shell
python batchnorm_adapt.py
```
## How It Works

AdaBN adapts a trained model to a new domain by:

- Computing BatchNorm statistics (per-channel mean and variance) on target domain data
- Replacing the source domain running statistics with the target domain statistics
- Keeping all learned weights, including the BatchNorm affine parameters, unchanged

This requires no additional training and adds no new parameters.
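The steps above can be sketched in plain PyTorch without the helpers in `utils.py`. This is a minimal illustration, not this repository's API: the function name `adapt_bn_stats` is hypothetical, and it uses PyTorch's built-in behavior that `momentum=None` makes BatchNorm accumulate a cumulative (equally weighted) average of batch statistics:

```python
import torch
import torch.nn as nn

def adapt_bn_stats(model, dataloader, device="cpu"):
    """Re-estimate BatchNorm running statistics on target-domain data.

    All learned weights (including BN affine parameters gamma/beta)
    stay fixed; only running_mean / running_var are recomputed.
    Hypothetical helper, not part of this repository.
    """
    model.to(device)
    model.eval()  # freeze everything by default
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None  # cumulative average over all target batches
            m.train()          # train mode so forward passes update buffers
    with torch.no_grad():      # no gradients needed; buffers update anyway
        for batch in dataloader:
            x = batch[0] if isinstance(batch, (tuple, list)) else batch
            model(x.to(device))
    model.eval()
    return model
```

Because the batches are weighted equally, the resulting `running_mean` equals the mean of the per-batch means over the target data, matching what AdaBN replaces the source statistics with.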