It is possible that your training in general is unstable, so BatchNorm's `running_mean` and `running_var` don't represent the true batch statistics.
http://pytorch.org/docs/master/nn.html?highlight=batchnorm#torch.nn.BatchNorm1d
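To see what "running statistics" means here, the small check below reproduces one update by hand. During a training-mode forward pass, PyTorch updates each buffer as an exponential moving average, `running = (1 - momentum) * running + momentum * batch_stat`, using the unbiased batch variance for `running_var`. The layer size, batch shape and seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3, momentum=0.1)  # 0.1 is also PyTorch's default momentum
x = torch.randn(8, 3)                 # arbitrary batch: 8 samples, 3 features

# Snapshot the buffers before the forward pass (initially zeros and ones).
old_mean = bn.running_mean.clone()
old_var = bn.running_var.clone()

bn.train()
bn(x)  # a training-mode forward pass updates running_mean / running_var

m = bn.momentum
expected_mean = (1 - m) * old_mean + m * x.mean(dim=0)
# running_var is updated with the unbiased (N - 1) batch variance
expected_var = (1 - m) * old_var + m * x.var(dim=0, unbiased=True)

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))
```

If training is unstable, the batch statistics fed into this average jump around, and the buffers end up describing none of them well.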
Try the following:
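- Increase the `momentum` argument in the BatchNorm constructor (the default is 0.1), so the running statistics weight recent batches more heavily.
- Before you call `model.eval()`, run a few inputs through `model` while it is still in training mode (forward pass only; you don't need to backward). This will help stabilize the `running_mean` / `running_var` values.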
Hope this helps.