Hello, I want to be able to check if two models have the same weights in their layers. After poking around, I couldn’t find a function that did this, so I implemented my own.
    import torch
    import torch.nn as nn


    def compareModelWeights(model_a, model_b):
        # _modules maps each direct child's name to the child module.
        module_a = model_a._modules
        module_b = model_b._modules
        if len(list(module_a.keys())) != len(list(module_b.keys())):
            return False
        a_modules_names = list(module_a.keys())
        b_modules_names = list(module_b.keys())
        for i in range(len(a_modules_names)):
            layer_name_a = a_modules_names[i]
            layer_name_b = b_modules_names[i]
            # Children must appear under the same names, in the same order.
            if layer_name_a != layer_name_b:
                return False
            layer_a = module_a[layer_name_a]
            layer_b = module_b[layer_name_b]
            # Recurse only into plain nn.Module / nn.Sequential containers.
            if (
                (type(layer_a) == nn.Module) or (type(layer_b) == nn.Module) or
                (type(layer_a) == nn.Sequential) or (type(layer_b) == nn.Sequential)
            ):
                if not compareModelWeights(layer_a, layer_b):
                    return False
            # Compare the weight tensors when both layers have one.
            if hasattr(layer_a, 'weight') and hasattr(layer_b, 'weight'):
                if not torch.equal(layer_a.weight.data, layer_b.weight.data):
                    return False
        return True
It’s recursive because modules can have a tree structure.
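One possible alternative to walking the tree by hand, as a rough sketch: `state_dict()` already flattens nested submodules into dotted keys, so comparing the two state dicts entry by entry covers the whole tree, including biases and buffers. The name `state_dicts_equal` is made up for illustration, and the sketch assumes both models live on the same device.

```python
import copy

import torch
import torch.nn as nn


def state_dicts_equal(model_a: nn.Module, model_b: nn.Module) -> bool:
    """Hypothetical helper: True if all parameters and buffers match exactly."""
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()
    # Nested submodules show up as dotted keys such as "2.0.bias",
    # so no explicit recursion is needed.
    if list(state_a.keys()) != list(state_b.keys()):
        return False
    # torch.equal checks that two tensors have the same size and elements.
    # Assumption: both models are on the same device.
    return all(torch.equal(state_a[key], state_b[key]) for key in state_a)


# Usage: a nested model and a deep copy of it.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Sequential(nn.Linear(3, 2)))
clone = copy.deepcopy(net)
same_before = state_dicts_equal(net, clone)

# Change a bias inside the nested Sequential of the copy.
with torch.no_grad():
    clone[2][0].bias.add_(1.0)
same_after = state_dicts_equal(net, clone)

print(same_before, same_after)
```

Here the comparison is exact; for models that are only expected to match up to floating-point noise, `torch.allclose` could be swapped in for `torch.equal`.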
If there is already a function that does this, I would love to use it in my code. If not, this might be a useful addition. Please let me know if there are any errors.