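Your model was trained with PyTorch 0.3.x but is being run with PyTorch 0.4.0 or later, which typically fails with AttributeError: 'BatchNorm2d' object has no attribute 'track_running_stats'.
PyTorch 0.4.0 added the track_running_stats attribute to BatchNorm layers, and modules pickled under 0.3.x do not have it, so you need to set it yourself.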
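For example, define a function that walks the module tree and sets the attribute on every BatchNorm2d layer:

import torch

def recursion_change_bn(module):
    if isinstance(module, torch.nn.BatchNorm2d):
        # Attribute introduced in PyTorch 0.4.0; absent from 0.3.x pickles.
        module.track_running_stats = True
    else:
        # Recurse into child modules; the change is made in place.
        for name, child in module._modules.items():
            recursion_change_bn(child)
    return module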
and call it once on the model after loading it:

check_point = torch.load(check_point_file_path)
model = check_point['net']
# A single call patches every nested layer; no outer loop is needed.
model = recursion_change_bn(model)
model.eval()
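If you want to confirm which layers actually lack the attribute (for example, to decide whether the patch is needed at all), a small check like the one below may help. It is a sketch: the helper name legacy_batchnorm_names is hypothetical, and it assumes the private base class _BatchNorm in torch.nn.modules.batchnorm, which BatchNorm1d/2d/3d derive from in the releases I know of.

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


def legacy_batchnorm_names(model: nn.Module) -> list:
    """Return qualified names of BatchNorm layers missing track_running_stats.

    Hypothetical helper: an empty list means no patching is needed.
    """
    return [
        name
        # named_modules() walks the whole tree, yielding (dotted_name, module).
        for name, m in model.named_modules()
        if isinstance(m, _BatchNorm) and not hasattr(m, "track_running_stats")
    ]
```

Calling it on check_point['net'] right after loading should list the affected layers, e.g. "1" or "2.0" for nested Sequential children.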
I have run 0.3.1 models this way in PyTorch 0.4.1 and PyTorch 1.0.0.
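A shorter sketch of the same fix, assuming you also want to cover BatchNorm1d and BatchNorm3d: nn.Module.modules() already iterates over the model and all of its submodules recursively, and the private base class _BatchNorm matches every BatchNorm variant. The helper name patch_legacy_batchnorm is hypothetical, and relying on a private class is an assumption that may not hold in every release.

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


def patch_legacy_batchnorm(model: nn.Module) -> nn.Module:
    """Add the track_running_stats attribute that PyTorch >= 0.4 expects.

    Hypothetical helper name. modules() yields the model itself and every
    nested submodule, so no manual recursion is required.
    """
    for m in model.modules():
        # _BatchNorm is the (private) base of BatchNorm1d/2d/3d.
        # Only fill the attribute in when it is missing, so layers that
        # deliberately use track_running_stats=False are left untouched.
        if isinstance(m, _BatchNorm) and not hasattr(m, "track_running_stats"):
            m.track_running_stats = True
    return model
```

It would be used exactly like recursion_change_bn: call it once on check_point['net'] before model.eval(). Setting the attribute only when it is missing is a deliberate choice; 0.3.x BatchNorm layers always tracked running statistics, so True matches their behavior.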