Training with BatchNorm in PyTorch

I’m wondering if I need to do anything special when training with BatchNorm in PyTorch. From my understanding, the gamma and beta parameters are learnable and are updated with gradients by the optimizer like any other weight. The running mean and variance, however, are not learned: they are updated during each forward pass as an exponential moving average controlled by the layer’s momentum argument.
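
To be concrete, my mental model of the running-stat update is something like this sketch (shapes made up; momentum = 0.1 is the nn.BatchNorm1d default, and as I understand it PyTorch’s momentum weights the new observation, the opposite of the usual EMA convention):

import torch

momentum = 0.1                  # nn.BatchNorm1d default
running_mean = torch.zeros(4)   # buffer starts at 0
running_var = torch.ones(4)     # buffer starts at 1

batch = torch.randn(32, 4)
batch_mean = batch.mean(dim=0)
batch_var = batch.var(dim=0, unbiased=True)  # the running stats track the unbiased variance, I believe

# exponential moving average of the batch statistics
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var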

  1. So do we need to tell the optimizer about the mean and variance updates, or does PyTorch take care of this automatically?
  2. Is there a way to access the mean and variance of the BN layer so that I can make sure they were changing while I trained the model?

If needed here is my model and training procedure:

import torch
import torch.nn as nn
import torch.nn.functional as F

def bn_drop_lin(n_in: int, n_out: int, bn: bool = True, p: float = 0.):
    "Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`, `n_out`) layers."
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0: layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    return nn.Sequential(*layers)
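
For example, bn_drop_lin(784, 128, p=0.5) (sizes made up) returns nn.Sequential(nn.BatchNorm1d(784), nn.Dropout(0.5), nn.Linear(784, 128)).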

class Model(nn.Module):
    def __init__(self, i, o, h=()):
        super().__init__()

        nodes = (i,) + h + (o,)
        # one bn_drop_lin block per consecutive pair of layer sizes
        self.layers = nn.ModuleList([bn_drop_lin(n_in, n_out, p=0.5)
                                     for n_in, n_out in zip(nodes[:-1], nodes[1:])])

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # flatten everything but the batch dimension
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))

        return self.layers[-1](x)
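
For instance (hypothetical sizes for a flattened 28×28 input and 10 classes):

net = Model(784, 10, h=(128,))
# two bn_drop_lin blocks: 784 -> 128 (followed by ReLU), then 128 -> 10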

Training:

for i, data in enumerate(trainloader):
    # get the inputs; data is a list of [inputs, labels]
    inputs, labels = data

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

This is an x-post from https://stackoverflow.com/questions/57865112/training-with-batchnorm-in-pytorch

  1. No. The running estimates are buffers, not parameters: they don’t require gradients and are updated automatically during the forward pass, so you don’t have to pass them to the optimizer.
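
One related thing to double-check in your loop: the running estimates are only updated while the model is in training mode, so make sure net.train() is called before training and net.eval() before validation or inference:

net.train()  # BN normalizes with batch statistics and updates running_mean / running_var
# ... your training loop ...
net.eval()   # BN switches to the stored running_mean / running_var and stops updating them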

  2. You can access them via my_bn_layer.running_mean and my_bn_layer.running_var.
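
For example, with your Model above each block’s first module is the BatchNorm1d layer, so a quick sketch to verify the stats move during training would be:

bn = net.layers[0][0]               # BatchNorm1d of the first bn_drop_lin block
before = bn.running_mean.clone()

# ... run a few training iterations ...

print((bn.running_mean - before).abs().max())  # nonzero if the running stats are updating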