Weird BN configuration

Hello, I am currently doing research on non-stationary data, and I finally found a high-performing configuration for the BN layers. However, the behavior of this configuration looks quite strange, so I couldn't figure out why it performs so well.

I’m trying to fine-tune the BN parameters of a pre-trained model with a simple entropy-minimization loss.

In this process, I set track_running_stats=False on the BN layers, but I did not reset the running estimates, i.e., I did not set running_mean=None and running_var=None.

With this configuration, how do the BN layers process the input data in train/eval mode?

I’ve observed that the outputs of the model in train mode and eval mode differ in the final prediction stage.

############### Configuration
model.train()

model.requires_grad_(False)

for m in model.modules():
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.requires_grad_(True)  # keep the affine parameters (weight/bias) trainable

        m.track_running_stats = False
        # running estimates intentionally left untouched:
        # m.running_mean = None
        # m.running_var = None

############### Finetuning
model.train()

optimizer.zero_grad()

outputs = model(x)

# softmax_entropy as in Tent (Ref 1 below): per-sample -(p * log p).sum(1), averaged over the batch
loss = softmax_entropy(outputs).mean(0)

loss.backward()

optimizer.step()


############### Predictions
"""
outputs1 and outputs2 are different even for the same input.
"""
model.train()
outputs1 = model(x)

model.eval()
outputs2 = model(x)
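For reference, here is a minimal standalone sketch (my own repro, not part of the pipeline above; it assumes a single nn.BatchNorm1d layer) that shows the same train/eval discrepancy:

import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(4)          # constructed with track_running_stats=True
bn.track_running_stats = False  # flipped afterwards: buffers stay non-None

x = torch.randn(8, 4) * 3 + 5   # batch stats far from the buffers' (0, 1) init

bn.train()
out_train = bn(x)               # normalized with batch statistics

bn.eval()
out_eval = bn(x)                # normalized with the frozen buffers

print(torch.allclose(out_train, out_eval))  # False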

From the docs:

track_running_stats (bool) – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes. Default: True
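The key detail is how _BatchNorm.forward decides between batch statistics and the running buffers. Paraphrased from the PyTorch source (a sketch, not the exact code; details vary across versions), the dispatch looks roughly like this:

# Paraphrased sketch of torch.nn.modules.batchnorm._BatchNorm.forward;
# exponential_average_factor is derived from self.momentum earlier in forward.
if self.training:
    bn_training = True  # train mode always normalizes with batch statistics
else:
    # eval mode falls back to batch statistics only if the buffers are None
    bn_training = (self.running_mean is None) and (self.running_var is None)

return F.batch_norm(
    input,
    # pass the buffers unless we are in train mode with track_running_stats=False;
    # F.batch_norm updates them only if they are passed and bn_training is True
    self.running_mean if not self.training or self.track_running_stats else None,
    self.running_var if not self.training or self.track_running_stats else None,
    self.weight, self.bias, bn_training, exponential_average_factor, self.eps,
)

If that sketch is accurate, your configuration (track_running_stats=False with non-None buffers) normalizes with batch statistics in train mode without updating the buffers, and with the frozen pre-trained buffers in eval mode, which is why outputs1 and outputs2 differ.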

Hello @ptrblck, thank you for your answer.

According to the docs, running_mean and running_var are initialized as None only when track_running_stats=False is passed at construction time.

But I’m configuring the BN layers after the pre-training stage, via the if isinstance(...) check in the for loop. The layers' __init__ is therefore never called again, so running_mean and running_var don't change.

I only change the track_running_stats attribute, so I guess it only affects the forward pass.
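A quick sanity check (my own sketch) confirms that flipping the attribute leaves the buffers in place:

import torch.nn as nn

bn = nn.BatchNorm1d(4)
print(bn.running_mean is None)  # False: buffers were created in __init__
bn.track_running_stats = False
print(bn.running_mean is None)  # still False: only the attribute changed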

In addition, I found that the performance also changes if I do set m.running_mean = None and m.running_var = None when configuring the BN layers.

Are you manipulating track_running_stats after initializing the layer? If so, why are you not setting it during the initialization directly? If you want to keep using the batch stats to normalize the inputs you could leave the layer in training mode.

I just want to fine-tune the parameters in the BN layers, so I was trying to continuously update them from the pre-trained weights after loading the pre-trained model.
In addition, manipulating track_running_stats after initializing the layer is common in test-time adaptation algorithms, so I followed those works.

Ref 1. tent/tent.py at master · DequanWang/tent · GitHub

Ref 2. NOTE/learner/note.py at main · TaesikGong/NOTE · GitHub

Unlike those, I just want to keep the running estimates fixed across all modes of the BN layers
(because the data I've used are temporally correlated, so estimating batch statistics leads to the accumulation of inappropriate statistics).
Previous works, however, either keep tracking the statistics of the input data, or stop tracking and initialize running_mean/var as None.
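One way to verify that the configuration above really keeps the estimates fixed (my own sketch; model and the fine-tuning step are the ones from the snippets above):

import torch

# snapshot the pre-trained buffers before adaptation
ref = {n: b.clone() for n, b in model.named_buffers() if "running_" in n}

# ... run fine-tuning steps as above ...

for n, b in model.named_buffers():
    if "running_" in n:
        assert torch.allclose(b, ref[n]), f"{n} drifted during adaptation"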

In my case, however, I observed severe performance degradation when I initialized running_mean/var as None.
In contrast, I observed a significant performance improvement when I left running_mean/var untouched.

So I want to know how the BN layers work in train/eval mode when I set track_running_stats=False but don't initialize running_mean/var:

for m in model.modules():
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.requires_grad_(True)

        m.track_running_stats = False
        # running estimates intentionally NOT initialized:
        # m.running_mean = None
        # m.running_var = None
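For completeness, a small sketch (my own, using a single nn.BatchNorm1d, not taken from the referenced repos) contrasting the two configurations in eval mode:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4) * 3 + 5  # stats far from the buffers' (0, 1) init

def make_bn(reset_buffers):
    bn = nn.BatchNorm1d(4)
    bn.track_running_stats = False
    if reset_buffers:
        bn.running_mean = None  # Tent-style: eval falls back to batch stats
        bn.running_var = None
    return bn.eval()

out_kept = make_bn(reset_buffers=False)(x)  # uses the frozen (0, 1) buffers
out_none = make_bn(reset_buffers=True)(x)   # uses the current batch statistics

print(torch.allclose(out_kept, out_none))   # False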