Batch Normalization - Disabling `track_running_stats` for some training batches, but not all?

I’m using a GAN structure where my discriminator trains on some labeled data batches, some unlabeled data batches, and some generated data batches. I’d like to keep `track_running_stats` turned on, but only for the labeled data. That is, during the unlabeled and generated batches, the batch norm layers should just use the statistics accumulated from the labeled training so far (as they do during an eval step). Is there an easy way to disable the updating of the batch norm statistics for these batches while keeping everything else about the network in train mode? Thank you for your time!

You could just set the BatchNorm layers into eval mode, either by indexing them directly (e.g. `model.bn1.eval()`) or recursively via `model.apply`.
Here is a small example for the latter approach:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.fc = nn.Linear(6*24*24, 2)  # expects 1x24x24 inputs

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def set_bn_eval(m):
    # switch only BatchNorm layers to eval mode
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval()


def set_bn_train(m):
    # switch only BatchNorm layers back to train mode
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.train()


model = MyModel()
model.apply(set_bn_eval)  # all other modules stay in train mode
```
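A minimal sketch of how the approach above might fit the alternating discriminator schedule from the question. The `discriminator_step` helper, the random placeholder batches, and the dummy loss are all hypothetical stand-ins for the real GAN training code; only `set_bn_eval` and the model come from the answer. Each step first calls `model.train()`, which resets every module (BatchNorm included) to train mode, and then freezes only the BatchNorm layers for the unlabeled and generated batches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.fc = nn.Linear(6 * 24 * 24, 2)

    def forward(self, x):
        x = self.bn1(F.relu(self.conv1(x)))
        return self.fc(x.view(x.size(0), -1))


def set_bn_eval(m):
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval()


def discriminator_step(model, optimizer, batch, update_bn_stats):
    # hypothetical helper: one optimizer step on a single batch
    model.train()                 # every module, BatchNorm included, in train mode
    if not update_bn_stats:
        model.apply(set_bn_eval)  # then freeze only the BatchNorm layers
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()  # dummy loss standing in for the real GAN loss
    loss.backward()
    optimizer.step()


torch.manual_seed(0)
model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# placeholder batches (random data, not real labeled/unlabeled/generated samples)
labeled = torch.randn(8, 1, 24, 24)
unlabeled = torch.randn(8, 1, 24, 24)
generated = torch.randn(8, 1, 24, 24)

discriminator_step(model, optimizer, labeled, update_bn_stats=True)
stats_after_labeled = model.bn1.running_mean.clone()

discriminator_step(model, optimizer, unlabeled, update_bn_stats=False)
discriminator_step(model, optimizer, generated, update_bn_stats=False)

# the running mean only changed during the labeled step
print(torch.equal(model.bn1.running_mean, stats_after_labeled))  # True
```

Because `model.train()` runs at the start of every step, `set_bn_train` is not needed in this pattern; with a single BatchNorm layer, `model.bn1.eval()` would work just as well as `model.apply(set_bn_eval)`.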
3 Likes

Great! And thank you for the script along with it. Just to confirm, is this the only change that switching batch normalization from train to eval mode makes? That is, nothing else about the batch normalization layer behaves differently in eval mode?
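One way to check this empirically is to run a single BatchNorm layer in both modes on the same batch. The sketch below uses random data and arbitrary sizes. It illustrates three behaviors of PyTorch's BatchNorm: in eval mode the layer normalizes with `running_mean`/`running_var` instead of the current batch's statistics, it stops updating those buffers and `num_batches_tracked`, and its affine `weight`/`bias` still receive gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(16, 4) * 3 + 5  # batch whose stats differ from the initial running stats

# train mode: normalize with batch statistics and update the running buffers
bn.train()
out_train = bn(x)
mean_after_train = bn.running_mean.clone()

# eval mode: normalize with the running statistics; buffers stay frozen
bn.eval()
out_eval = bn(x)
out_eval.sum().backward()  # gradients still reach the affine parameters

print(torch.equal(bn.running_mean, mean_after_train))  # True: no update in eval mode
print(bn.num_batches_tracked.item())                   # 1: only the train-mode call counted
print(torch.allclose(out_train, out_eval))             # False: different statistics were used
print(bn.weight.grad is not None)                      # True
```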

2 Likes