How to run the model to only update the Batch normalization statistics

Hi,
The statistics in batch normalization can be obtained in two ways:
1. With track_running_stats=True, the running average accumulated over training is used.
2. With track_running_stats=False, the statistics of each individual batch are used.
I wonder how to do it as in the original paper: run over the dataset again and collect the statistics. In other words, how can I only update the statistics, without updating the other parameters, when running the model?

Thanks.


Since the running stats are updated during the forward pass, you could set the model to model.train() (or just the BatchNorm layers, if you wish) and just do forward passes on your data, without calculating the backward pass or updating any parameters.
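For example, a minimal sketch with a stand-in model and random data (the names are placeholders, not from the thread):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))  # stand-in model
loader = [torch.randn(8, 3, 24, 24) for _ in range(10)]  # stand-in data

model.train()  # BatchNorm layers update their running stats in train mode
with torch.no_grad():  # no graph is built and no optimizer step is taken
    for data in loader:
        model(data)  # forward pass only; running_mean / running_var are updated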


Thanks for the solution.

So if I want to use the individual batch statistics for validation (during the training process), and obtain the statistics over the whole dataset at the end of training, is the following process correct?
1. Set track_running_stats to True, to keep a running average of the statistics.
2. Train the model.
3. Set the model to .eval() but set the batch norm layers to .train(), to validate the model using individual batch statistics? This does not work as-is, since it will update the running statistics with the validation data. One trick is to reset the saved running statistics to their initial state (0 and 1) before the final estimation pass, so that the validation data has no effect on the final statistics (see also the sketch after this list). Is there an easier way?
4. Set the model to .train(), but don't do the backward pass, to estimate the statistics over the whole dataset. (Here the estimate still starts from the running average accumulated during training. Is there any way to avoid this?)
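As an aside, one alternative to resetting the statistics in step 3 above (a minimal sketch, not from the thread; the helper names are made up) is to snapshot the running stats before such a validation pass and restore them afterwards:

import torch.nn as nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def snapshot_bn_stats(model):
    # copy the current running stats of every BatchNorm layer
    return {name: (m.running_mean.clone(), m.running_var.clone())
            for name, m in model.named_modules() if isinstance(m, _BN_TYPES)}

def restore_bn_stats(model, stats):
    # write the saved running stats back into the BatchNorm layers
    for name, m in model.named_modules():
        if name in stats:
            m.running_mean.copy_(stats[name][0])
            m.running_var.copy_(stats[name][1])

# stats = snapshot_bn_stats(model)   # before validating with batch stats
# ...validation with the BN layers in train mode...
# restore_bn_stats(model, stats)     # the training-time average is intact again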

How do I correctly implement “use the individual batch statistics for validation (during the training process), and obtain the statistics over the dataset at the end of training”?

Is there any way to toggle track_running_stats during the training process, similarly to the .train() and .eval() options?

Thanks.

I don’t really understand this question.
You would like to use the individual batch statistics during validation, while using the running stats during training?
If you set the model to .eval(), the running stats will be used if track_running_stats=True. Otherwise, the batch stats will be used.
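For completeness, switching only the BatchNorm layers to batch statistics for a validation pass could look like this (a sketch with a stand-in model):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))  # stand-in model

model.eval()  # put everything into eval mode first
for m in model.modules():
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        m.train()  # these layers now normalize with the current batch stats
# caveat: in train mode the BatchNorm layers also keep updating their running stats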

Do you want the running average of the dataset statistics or the “global” stats?
If the latter, you could calculate them offline and just set the running stats to these values.
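Setting the running stats could look like this (a minimal sketch; global_mean and global_var are placeholders for the offline-computed per-channel statistics):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
global_mean = torch.zeros(3)  # placeholder for the offline-computed mean
global_var = torch.ones(3)  # placeholder for the offline-computed variance

bn.running_mean.copy_(global_mean)
bn.running_var.copy_(global_var)
bn.eval()  # in eval mode these stats are used for normalization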


For training, the statistics do not matter, as they do not affect the training process.
For validation, I want to use the batch stats.
For test, I want to use the global stats over the whole dataset.

So the solution would be: set track_running_stats=False, calculate the global stats offline, and assign them to the batch norm layers?
Is there any easy way to calculate the global stats?
Thanks.

I think you should set track_running_stats=False, because you don’t want to use the running stats in any case.
You could calculate the global stats using this example:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(100, 3, 24, 24)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=5, shuffle=False, num_workers=1)

tmp_mean = 0  # accumulates the per-channel sum
tmp_var = 0  # accumulates the per-channel sum of squares
nb_elems = 0
for data in loader:
    b, _, h, w = data.size()
    tmp_mean += data.sum(3).sum(2).sum(0)
    tmp_var += torch.pow(data, 2).sum(3).sum(2).sum(0)
    nb_elems += b * h * w

global_mean = tmp_mean / nb_elems
# unbiased per-channel variance: (sum(x^2) - n * mean^2) / (n - 1)
global_var = (tmp_var - nb_elems * global_mean ** 2) / (nb_elems - 1)

Okay. Thanks for the example.
I was actually wondering whether there is a simple way to directly get the batch stats during the forward pass, so that the global stats could be obtained simply by averaging them. I mean, the batch stats are already calculated and used even when track_running_stats=False is set.
Thanks.
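(For reference, one way to sketch that idea, assuming a 2D BatchNorm layer: a forward pre-hook can record the per-channel batch stats as the data flows through, and the global stats can then be averaged afterwards. The names below are made up for illustration.)

import torch
import torch.nn as nn

batch_means, batch_vars = [], []

def record_stats(module, inputs):
    x = inputs[0]
    # per-channel stats over the batch and spatial dims, as BatchNorm computes them
    batch_means.append(x.mean(dim=(0, 2, 3)))
    batch_vars.append(x.var(dim=(0, 2, 3), unbiased=False))

bn = nn.BatchNorm2d(3, track_running_stats=False)
handle = bn.register_forward_pre_hook(record_stats)

for _ in range(10):
    bn(torch.randn(8, 3, 24, 24))  # stats are recorded during each forward pass
handle.remove()

global_mean = torch.stack(batch_means).mean(0)
# caveat: averaging per-batch variances only approximates the dataset variance,
# since the spread of the batch means is not accounted for
global_var = torch.stack(batch_vars).mean(0)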