It is possible that your training is unstable in general, so BatchNorm’s running_mean and running_var don't represent the true batch statistics.
http://pytorch.org/docs/master/nn.html?highlight=batchnorm#torch.nn.BatchNorm1d
Try the following:
- Change the momentum term in the BatchNorm constructor to a higher value (the default is 0.1).
- Before you set model.eval(), run a few inputs through the model (just a forward pass; you don't need to call backward). This will help stabilize the running_mean / running_var values.
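
The two suggestions above can be sketched roughly as follows. This is only an illustration under assumptions: the toy model, the momentum value of 0.5, and the random warmup_batches are placeholders I made up, so substitute your own network and a few batches from your real training data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model; replace with your own network.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.BatchNorm1d(16, momentum=0.5),  # higher than the default of 0.1
    nn.ReLU(),
    nn.Linear(16, 2),
)

# Placeholder data; in practice take a few batches from your training loader.
warmup_batches = [torch.randn(32, 8) for _ in range(10)]

bn = model[1]
mean_before = bn.running_mean.clone()

model.train()          # running statistics are only updated in training mode
with torch.no_grad():  # forward pass only, no gradients or backward needed
    for x in warmup_batches:
        model(x)

mean_after = bn.running_mean.clone()

model.eval()           # now switch to evaluation mode
```

Comparing mean_before and mean_after (or printing bn.running_var) before and after the warm-up loop is a quick way to check whether the running statistics are still moving a lot.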
Hope this helps.