'BatchNorm2d' object has no attribute 'track_running_stats'

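Your model was trained and saved with PyTorch 0.3.x, but you are loading it in PyTorch 0.4.0 or later, where BatchNorm2d expects a track_running_stats attribute that the old pickled modules do not have.
Set the attribute on every BatchNorm2d layer yourself.
For example, define this function: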

import torch

def recursion_change_bn(module):
    # Old checkpoints lack this attribute, so add it to every BatchNorm2d.
    if isinstance(module, torch.nn.BatchNorm2d):
        module.track_running_stats = True
    else:
        # Recurse into the children; they are patched in place.
        for child in module.children():
            recursion_change_bn(child)
    return module

Then call it right after you load the model:

check_point = torch.load(check_point_file_path)
model = check_point['net']
# One call is enough: the function already walks the whole module tree.
recursion_change_bn(model)
model.eval()

I have run a 0.3.1 model this way in PyTorch 0.4.1 and PyTorch 1.0.0.
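
If you prefer not to write the recursion by hand, the same patch can probably be done with model.modules(), which already visits every submodule. This is only a sketch: the name patch_batchnorm is made up for illustration, extending the fix to BatchNorm1d and BatchNorm3d is my assumption, and the old checkpoint is simulated by deleting the attribute from a fresh layer.

```python
import torch
from torch import nn

# Assumption: the 1d and 3d variants need the same patch as BatchNorm2d.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def patch_batchnorm(model: nn.Module) -> nn.Module:
    """Add track_running_stats to batch-norm layers that are missing it."""
    # modules() yields the model itself and every nested submodule.
    for m in model.modules():
        if isinstance(m, BN_TYPES) and not hasattr(m, "track_running_stats"):
            m.track_running_stats = True
    return model


# Simulate a module unpickled from an old checkpoint by removing the attribute.
demo = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()),
)
del demo[1][0].track_running_stats

patch_batchnorm(demo)
demo.eval()
out = demo(torch.randn(1, 3, 16, 16))  # forward pass works again in eval mode
```

Layers that already have the attribute are left untouched, so it should be safe to call this on a model saved with a newer version too.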
