I found this to be a better approach: because it iterates over `state_dict()`, it also compares the batch-norm layers' `running_mean` and `running_var` buffers, not just the learnable parameters.
def compare_models(model_1, model_2):
    """Compare two models entry by entry through their ``state_dict()``.

    Entries of the two state dicts are paired by position. Every pair whose
    values are not equal according to ``torch.equal`` is reported on stdout,
    and a confirmation message is printed when nothing differs.

    Args:
        model_1: Object exposing ``state_dict()``; the values of the returned
            mapping are passed to ``torch.equal``.
        model_2: Object to compare against, with the same requirements.

    Returns:
        ``True`` if every pair of entries is equal, ``False`` otherwise.

    Raises:
        ValueError: If the two state dicts hold a different number of
            entries, or if a differing pair of entries is stored under
            different keys (the models are not aligned, so a per-entry
            report would be meaningless).
    """
    entries_1 = list(model_1.state_dict().items())
    entries_2 = list(model_2.state_dict().items())

    # zip() stops at the shorter input, so without this check a model with
    # extra trailing entries would be reported as a perfect match.
    if len(entries_1) != len(entries_2):
        raise ValueError(
            'Cannot compare models: state dicts have '
            f'{len(entries_1)} and {len(entries_2)} entries'
        )

    models_differ = 0
    for (key_1, value_1), (key_2, value_2) in zip(entries_1, entries_2):
        if torch.equal(value_1, value_2):
            continue
        # Keys are only checked where values differ: equal tensors stored
        # under different names are still treated as matching.
        if key_1 != key_2:
            raise ValueError(
                f'Cannot compare models: key {key_1!r} is paired with '
                f'{key_2!r}'
            )
        models_differ += 1
        print('Mismatch found at', key_1)

    if models_differ == 0:
        print('Models match perfectly! :)')
    return models_differ == 0