Hello, I want to be able to check if two models have the same weights in their layers. After poking around, I couldn’t find a function that did this, so I implemented my own.
    import torch
    import torch.nn as nn

    def compareModelWeights(model_a, model_b):
        module_a = model_a._modules
        module_b = model_b._modules
        # A different number of child modules means the models cannot match.
        if len(list(module_a.keys())) != len(list(module_b.keys())):
            return False
        a_modules_names = list(module_a.keys())
        b_modules_names = list(module_b.keys())
        for i in range(len(a_modules_names)):
            layer_name_a = a_modules_names[i]
            layer_name_b = b_modules_names[i]
            if layer_name_a != layer_name_b:
                return False
            layer_a = module_a[layer_name_a]
            layer_b = module_b[layer_name_b]
            # Recurse into every child: containers (nn.Sequential or custom
            # nn.Module blocks) are walked, leaf layers return True immediately.
            if not compareModelWeights(layer_a, layer_b):
                return False
            # Some layers (e.g. norms without affine params) have weight set to None.
            weight_a = getattr(layer_a, 'weight', None)
            weight_b = getattr(layer_b, 'weight', None)
            if weight_a is not None and weight_b is not None:
                if not torch.equal(weight_a.data, weight_b.data):
                    return False
        return True
It’s recursive because modules can have a tree structure.
If there is already a function that does this, I would love to use it in my code. If not, this might be a useful addition. Please let me know if there are any errors.
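For comparison, `nn.Module.named_parameters()` already walks the module tree recursively, so the manual recursion over `_modules` can be avoided. A minimal sketch, assuming both models should share the same architecture (the name `same_weights` is hypothetical); unlike a weight-only check, it also compares biases and any other registered parameter:

```python
import copy

import torch
import torch.nn as nn


def same_weights(model_a: nn.Module, model_b: nn.Module) -> bool:
    """Return True if both models have identically named, equal parameters."""
    params_a = dict(model_a.named_parameters())
    params_b = dict(model_b.named_parameters())
    # Different parameter names means different architectures.
    if params_a.keys() != params_b.keys():
        return False
    # torch.equal returns False for mismatched shapes as well as mismatched values.
    return all(torch.equal(params_a[name], params_b[name]) for name in params_a)


# Usage: a deep copy matches until one of its parameters is changed.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
clone = copy.deepcopy(net)
before = same_weights(net, clone)
with torch.no_grad():
    clone[2][0].bias.add_(1.0)  # changes only a bias inside the nested Sequential
after = same_weights(net, clone)
```

The bias change in the usage example is exactly the kind of difference that a check looking only at `.weight` would miss.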
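You can get model1.parameters() and model2.parameters() and compare them pairwise:

    def compare_parameters(model1, model2):
        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            # ne() is True wherever elements differ; any True means a mismatch.
            if p1.data.ne(p2.data).sum() > 0:
                return False
        return True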
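One caveat with a plain `zip` loop: `zip` stops at the shorter sequence, so a model whose parameters are a prefix of another's is reported as equal. A sketch that checks the parameter count and shapes first (the name `same_parameters` is illustrative, not an existing PyTorch function):

```python
import copy

import torch.nn as nn


def same_parameters(model1: nn.Module, model2: nn.Module) -> bool:
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip() would silently ignore the extra parameters of the longer model.
    if len(params1) != len(params2):
        return False
    for p1, p2 in zip(params1, params2):
        # Check shapes first: ne() broadcasts, so mismatched shapes could still compare.
        if p1.shape != p2.shape or p1.data.ne(p2.data).sum() > 0:
            return False
    return True


small = nn.Linear(3, 3)
large = nn.Sequential(copy.deepcopy(small), nn.Linear(3, 3))

# A plain zip over the parameters only sees the first two tensors, which match.
zip_only = all(
    p1.data.ne(p2.data).sum() == 0 for p1, p2 in zip(small.parameters(), large.parameters())
)
checked = same_parameters(small, large)
```

Here `zip_only` is True even though `large` has an extra layer, while `same_parameters` reports the difference.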
Hey, just a small correction:
It should be:
I don’t think an abs is necessary (or wanted)… you probably want to distinguish between negative and positive values.
ne (not equal) returns Boolean values, which can simply be summed.
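A small illustration of that point: `ne` compares element by element, so a sign flip already counts as a mismatch without any `abs`, and summing the Boolean result gives the number of differing elements. `torch.equal` gives the same yes/no answer in a single call.

```python
import torch

a = torch.tensor([1.0, -2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0])

mismatch = a.ne(b)                    # elementwise: [False, True, False]
num_mismatches = int(mismatch.sum())  # each True counts as 1 when summed
identical = torch.equal(a, b)         # False: same size, but values differ
```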
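I found this to be a better approach: state_dict() also contains buffers, so it compares batch_norm layers' running_mean and running_var as well:

    import torch

    def compare_models(model_1, model_2):
        models_differ = 0
        for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
            # Each item is a (name, tensor) pair; compare the tensors.
            if not torch.equal(key_item_1[1], key_item_2[1]):
                models_differ += 1
                if key_item_1[0] == key_item_2[0]:
                    print('Mismatch found at', key_item_1[0])
                else:
                    # Different key names: the architectures do not line up.
                    raise ValueError('State dicts have different keys')
        if models_differ == 0:
            print('Models match perfectly! :)')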
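To see why comparing `state_dict()` catches more than comparing `parameters()`: a forward pass in training mode updates BatchNorm's running statistics, which are registered as buffers rather than parameters. A minimal sketch, assuming a toy `nn.Sequential` model chosen only for illustration:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model_1 = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model_2 = copy.deepcopy(model_1)

# A train-mode forward pass updates BatchNorm buffers (running_mean,
# running_var, num_batches_tracked) even though no optimizer step is taken.
model_2.train()
with torch.no_grad():
    model_2(torch.randn(8, 4))

params_equal = all(
    torch.equal(p1, p2) for p1, p2 in zip(model_1.parameters(), model_2.parameters())
)
differing_keys = [
    k
    for (k, t1), (_, t2) in zip(model_1.state_dict().items(), model_2.state_dict().items())
    if not torch.equal(t1, t2)
]
```

Here `params_equal` is True, while `differing_keys` lists BatchNorm buffers such as `1.num_batches_tracked`.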
Why isn’t this, or something like it, a built-in method?
Feature requests and PRs are always welcome if you think this feature is useful and widely needed.
Adding an answer since I spent half an hour trying to figure out the same thing while checking whether a network was really updating (spoiler: it wasn’t).
The easiest way I found to compare two networks directly is to convert each state_dict to its string representation and compare the strings:
    state_a = str(network.state_dict())
    # ... train `network` for one or more steps here ...
    state_b = str(network.state_dict())
    if state_a == state_b:
        print("Network not updating.")
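One caution with comparing strings: when printing, PyTorch summarizes tensors with more than 1000 elements (only the first and last few values appear around `...`) and rounds floats for display, so two different state dicts can produce the same string. A sketch that instead snapshots the state dict with cloned tensors and compares them with `torch.equal`. Cloning is needed because `state_dict()` returns references to the live tensors. `snapshot` and `states_equal` are hypothetical helper names:

```python
import torch
import torch.nn as nn

# 1) Why string comparison can miss an update.
a = torch.zeros(2000)
b = a.clone()
b[1000] = 1.0                     # hidden inside the "..." of the printed summary
strings_match = str(a) == str(b)  # True despite the change
values_match = torch.equal(a, b)  # False


def snapshot(model: nn.Module) -> dict:
    # state_dict() holds references to the live tensors, so copy them.
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


def states_equal(state_a: dict, state_b: dict) -> bool:
    return state_a.keys() == state_b.keys() and all(
        torch.equal(state_a[k], state_b[k]) for k in state_a
    )


# 2) Checking whether one optimizer step changes the network.
torch.manual_seed(0)
network = nn.Linear(3, 1)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

before = snapshot(network)
loss = network(torch.randn(4, 3)).pow(2).sum()
loss.backward()
optimizer.step()
after = snapshot(network)

network_updated = not states_equal(before, after)
```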