You could debug this by comparing each module's parameters before and after training:
def areTorchModulesEqual(module1, module2):
    # Two modules are equal when every pair of corresponding parameters
    # matches element-wise.
    for p1, p2 in zip(module1.parameters(), module2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True
def whichModulesHaveBeenUpdated(model1_list, model2_list):
    # Walk both module lists in lockstep and report which entries changed.
    for index, (module1, module2) in enumerate(zip(model1_list, model2_list)):
        if areTorchModulesEqual(module1, module2):
            print("At index " + str(index) + " Module1 and Module2 are equal, below is the printed module")
        else:
            print("At index " + str(index) + " Module1 and Module2 are not equal, below is the printed module")
        print(module1)
        print("\n")
    print("End of Modules")
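For example, you could snapshot the model with copy.deepcopy before training and compare afterwards. The sketch below is self-contained (it uses a toy model and a helper equivalent to areTorchModulesEqual, here via torch.equal), so the model, layer sizes, and single SGD step are illustrative assumptions, not your setup:

```python
import copy
import torch
import torch.nn as nn

# Equivalent to areTorchModulesEqual above: all parameter pairs match element-wise.
def are_equal(m1, m2):
    return all(torch.equal(p1, p2) for p1, p2 in zip(m1.parameters(), m2.parameters()))

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))  # toy model
before = copy.deepcopy(model)  # snapshot before training

# One dummy optimisation step on random data
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(16, 4)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

# Layers with learnable parameters should now differ from the snapshot;
# parameter-free layers (like ReLU) compare equal.
for index, (m_new, m_old) in enumerate(zip(model, before)):
    status = "updated" if not are_equal(m_new, m_old) else "unchanged"
    print(index, type(m_new).__name__, status)
```

If a layer that should learn still compares equal after training, its parameters are not receiving gradient updates (e.g. it was left out of the optimizer, or its gradients are detached somewhere).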
Also, does your model contain batch-normalization or dropout layers? Both behave differently in training and evaluation mode, so make sure you call model.eval() before running your test set.
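As a quick illustration of why that matters, here is a minimal sketch (toy layer sizes, not your model) showing that dropout makes forward passes non-deterministic in train mode but deterministic after model.eval():

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

net.train()   # dropout active: repeated forward passes usually differ
y1, y2 = net(x), net(x)

net.eval()    # dropout disabled: repeated forward passes are identical
z1, z2 = net(x), net(x)
print(torch.equal(z1, z2))  # eval-mode outputs match
```

The same applies to batch normalization, which uses per-batch statistics in train mode and running statistics in eval mode.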
It could also be an inconsistency in your test data iterator, e.g. shuffling or transforms that differ between runs.