Momentum causes NaN in batchnorm

Sorry if this isn’t the right forum; I couldn’t find a topic that fits my question.

I created a simple NN class for fully connected layers with the option of adding batch normalization:

import torch.nn as nn

class FC_batchnorm_net(nn.Module):
    def __init__(self, input_sizes, output_size, batchnorm=False, activation='ReLU'):
        super().__init__()  # must run before any submodule is assigned
        self.n_layers = len(input_sizes)
        # fc1 .. fc{n-1} map input_sizes[i-1] -> input_sizes[i]; fc{n} maps to the output
        self.fc_layers = nn.ModuleDict({'fc' + str(i): nn.Linear(input_sizes[i-1], input_sizes[i])
                                        for i in range(1, self.n_layers)})
        self.fc_layers['fc' + str(self.n_layers)] = nn.Linear(input_sizes[self.n_layers-1], output_size)

        # batchnorm is True/False, or a boolean list with one entry per element of input_sizes
        if isinstance(batchnorm, bool):
            batchnorm = [batchnorm] * len(input_sizes)
        elif len(batchnorm) != len(input_sizes):
            raise ValueError("If batchnorm isn't set to True or False, it must be a boolean "
                             "list of the same length as input_sizes!")
        # bn{i} normalizes the output of fc{i}; an empty ModuleDict means no batchnorm at all
        self.batchnorm = nn.ModuleDict({'bn' + str(i): nn.BatchNorm1d(input_sizes[i])
                                        for i in range(1, self.n_layers) if batchnorm[i]})

        if activation == 'ReLU':
            self.activation = nn.ReLU()
        else:
            raise ValueError("Unsupported activation: " + str(activation))

    def forward(self, x):
        out = x
        # Hidden layers: Linear -> optional BatchNorm1d -> activation
        for i in range(1, self.n_layers):
            out = self.fc_layers['fc' + str(i)](out)
            if 'bn' + str(i) in self.batchnorm:
                out = self.batchnorm['bn' + str(i)](out)
            out = self.activation(out)
        # Output layer: no batchnorm or activation
        return self.fc_layers['fc' + str(self.n_layers)](out)

I created a basic network using this, with 3 FC layers and a BN layer after each of the two inner layers. However, when I train the network with a momentum higher than 0.1, the parameters of the second BN layer become NaN after a few batches, and I cannot figure out why. The inputs are simply vectors of normalized floats. Any ideas?
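A minimal sketch for narrowing down where the NaNs first appear. It assumes a stand-in `nn.Sequential` with the same layout (3 `Linear` layers, `BatchNorm1d` after the two inner ones); the layer sizes, SGD settings, random data and the helper `first_non_finite` are hypothetical. After every optimizer step it checks all parameters and buffers (including the BN running stats) for NaN/Inf and reports the first one it finds.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for FC_batchnorm_net([16, 32, 32], 4, batchnorm=True):
# 3 Linear layers with BatchNorm1d after the two inner ones.
model = nn.Sequential(
    nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(),
    nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(),
    nn.Linear(32, 4),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # assumed settings
loss_fn = nn.MSELoss()

def first_non_finite(module):
    """Return the name of the first float parameter/buffer containing NaN or Inf, else None."""
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        # skip integer buffers such as BatchNorm's num_batches_tracked
        if t.is_floating_point() and not torch.isfinite(t).all():
            return name
    return None

for step in range(20):
    x = torch.randn(64, 16)  # placeholder for the normalized input vectors
    y = torch.randn(64, 4)   # placeholder targets
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    bad = first_non_finite(model)
    if bad is not None:
        print(f"step {step}: non-finite values in {bad}")
        break
```

Calling `torch.autograd.set_detect_anomaly(True)` before training can additionally point at the backward operation that first produced a NaN, at a noticeable speed cost.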

I haven’t seen such an effect before, but you could check the stats of the activations going into the BN layers during training and see whether they are exploding.
Using a higher momentum value increases the weight of the current batch’s activation stats when updating the running stats, so it seems these values might be large.
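The running-stat update mentioned above follows the formula in the `BatchNorm1d` documentation, running = (1 - momentum) * running + momentum * batch_stat. A minimal sketch, assuming a standalone `BatchNorm1d(8)` with a hypothetical `momentum=0.5` and synthetic activations, that logs the stats of the BN input through a forward hook (one way to check the input activation stats) and verifies the update by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

momentum = 0.5  # hypothetical value above the default of 0.1
bn = nn.BatchNorm1d(8, momentum=momentum)
bn.train()

# Forward hook that records mean, std and max |x| of the activations entering the BN layer.
input_stats = []
def log_input_stats(module, inputs, output):
    x = inputs[0].detach()
    input_stats.append((x.mean().item(), x.std().item(), x.abs().max().item()))
handle = bn.register_forward_hook(log_input_stats)

x = 3.0 + 2.0 * torch.randn(32, 8)  # synthetic activations with nonzero mean and larger spread
old_mean = bn.running_mean.clone()
old_var = bn.running_var.clone()
bn(x)

# Documented update; the running variance uses the unbiased batch variance (x.var's default).
expected_mean = (1 - momentum) * old_mean + momentum * x.mean(dim=0)
expected_var = (1 - momentum) * old_var + momentum * x.var(dim=0)
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))
print(input_stats[-1])
handle.remove()
```

With `momentum` close to 1 the running estimates essentially copy the last batch’s statistics, so large input stats show up in the running stats immediately. The same hook could be attached to each `BatchNorm1d` in the real model to see whether the inputs to the second BN layer grow before the NaNs appear.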

Thanks. For now, I worked around the issue by redoing the hyperparameter search from scratch after adding the batchnorm layers, but I’ll keep that in mind if I run into the issue again.