Hello, I want to be able to check if two models have the same weights in their layers. After poking around, I couldn’t find a function that did this, so I implemented my own.
```python
import torch


def compareModelWeights(model_a, model_b):
    # Only `weight` tensors are compared; biases and buffers such as
    # BatchNorm running statistics are not checked.

    # Compare this module's own weight, if it has one. getattr with a default
    # avoids errors on modules without a `weight` attribute, and the None check
    # covers layers built with affine=False / elementwise_affine=False.
    weight_a = getattr(model_a, 'weight', None)
    weight_b = getattr(model_b, 'weight', None)
    if (weight_a is None) != (weight_b is None):
        return False
    if weight_a is not None and not torch.equal(weight_a.data, weight_b.data):
        return False

    # Compare the direct children in registration order.
    children_a = list(model_a.named_children())
    children_b = list(model_b.named_children())
    if len(children_a) != len(children_b):
        return False
    for (layer_name_a, layer_a), (layer_name_b, layer_b) in zip(children_a, children_b):
        if layer_name_a != layer_name_b:
            return False
        # Recurse into every child. An exact check such as
        # type(layer) == nn.Module is False for subclass instances
        # (custom modules), so it would silently skip them.
        if not compareModelWeights(layer_a, layer_b):
            return False
    return True
```
It’s recursive because modules can be nested (for example, an `nn.Sequential` inside a custom `nn.Module`), so the layers of a model form a tree.
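As an alternative that avoids explicit recursion, `state_dict()` already flattens the module tree into dotted keys such as `"0.weight"`, so comparing two state dicts entry by entry covers every nested layer. The sketch below is not a single built-in PyTorch function; `state_dicts_equal` is a hypothetical helper name. It assumes both models live on the same device, and note that it compares all parameters (including biases) and persistent buffers, not only `weight`.

```python
import torch
import torch.nn as nn


def state_dicts_equal(model_a: nn.Module, model_b: nn.Module) -> bool:
    """Hypothetical helper: True if both models have identical state_dict entries."""
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()
    # The keys are dotted paths for every parameter and persistent buffer,
    # so nested submodules are covered without explicit recursion.
    if sd_a.keys() != sd_b.keys():
        return False
    # torch.equal returns False when shapes or values differ.
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)


if __name__ == "__main__":
    net_a = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
    net_b = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
    # Copy net_a's values into net_b so the two models match exactly.
    net_b.load_state_dict(net_a.state_dict())
    print(state_dicts_equal(net_a, net_b))  # True
    # Change one bias and the comparison fails.
    with torch.no_grad():
        net_b[2].bias.add_(1.0)
    print(state_dicts_equal(net_a, net_b))  # False
```

If only the weights matter, filtering the keys (for example, keeping those that end in `"weight"`) before comparing would mirror the original function more closely.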
If there is already a function that does this, I would love to use it in my code. If not, this might be a useful addition. Please let me know if there are any errors.