Check if models have same weights

Hello, I want to be able to check if two models have the same weights in their layers. After poking around, I couldn’t find a function that did this, so I implemented my own.

import torch
import torch.nn as nn

def compareModelWeights(model_a, model_b):
    module_a = model_a._modules
    module_b = model_b._modules
    a_modules_names = list(module_a.keys())
    b_modules_names = list(module_b.keys())
    # Different numbers of direct children mean different architectures
    if len(a_modules_names) != len(b_modules_names):
        return False
    for layer_name_a, layer_name_b in zip(a_modules_names, b_modules_names):
        if layer_name_a != layer_name_b:
            return False
        layer_a = module_a[layer_name_a]
        layer_b = module_b[layer_name_b]
        # Recurse into containers (nn.Sequential, custom nn.Module subclasses, ...)
        if len(layer_a._modules) > 0 or len(layer_b._modules) > 0:
            if not compareModelWeights(layer_a, layer_b):
                return False
        # Compare weights of layers that have them (weight can be None, e.g. BatchNorm with affine=False)
        weight_a = getattr(layer_a, 'weight', None)
        weight_b = getattr(layer_b, 'weight', None)
        if weight_a is not None and weight_b is not None:
            if not torch.equal(weight_a, weight_b):
                return False
    return True

It’s recursive because modules can have a tree structure.
If there is already a function that does this, I would love to use it in my code. If not, this might be a useful addition. Please let me know if there are any errors.
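For comparison, here is a minimal sketch that avoids the manual recursion: named_parameters() already walks every nested submodule and yields dotted names such as 1.1.weight, so checking that the names match and calling torch.equal on each pair covers the whole tree. The helper name same_weights is hypothetical, and it checks parameters only (weights and biases), not buffers such as BatchNorm running statistics.

```python
import torch
import torch.nn as nn


def same_weights(model_a: nn.Module, model_b: nn.Module) -> bool:
    """Hypothetical helper: True if both models have identically named, equal parameters."""
    # named_parameters() recurses through nested submodules on its own
    params_a = dict(model_a.named_parameters())
    params_b = dict(model_b.named_parameters())
    # Different parameter names mean different architectures
    if params_a.keys() != params_b.keys():
        return False
    # torch.equal is True only if the shapes and all elements match exactly
    return all(torch.equal(params_a[name], params_b[name]) for name in params_a)


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 3), nn.Sequential(nn.ReLU(), nn.Linear(3, 2)))
twin = nn.Sequential(nn.Linear(4, 3), nn.Sequential(nn.ReLU(), nn.Linear(3, 2)))
twin.load_state_dict(net.state_dict())  # copy the weights over

print([name for name, _ in net.named_parameters()])
# ['0.weight', '0.bias', '1.1.weight', '1.1.bias']
print(same_weights(net, twin))  # True
```

The dotted names show that the nested nn.Sequential is flattened for you, which is the tree walk the recursive version does by hand.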


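You can get model1.parameters() and model2.parameters() and use:

import torch

def same_parameters(model1, model2):
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        # ne() marks differing elements; any nonzero count means a mismatch
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True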
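One caveat with the zip-based snippet: zip() stops silently at the shorter of the two parameter lists, so a model whose parameters are a prefix of another's would be reported as equal. A minimal sketch that guards against this, and against shape mismatches (which ne() would either broadcast or reject with an error), might look like this; parameters_equal is a hypothetical name.

```python
import torch
import torch.nn as nn


def parameters_equal(model1: nn.Module, model2: nn.Module) -> bool:
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip() silently stops at the shorter list, so compare the counts first
    if len(params1) != len(params2):
        return False
    for p1, p2 in zip(params1, params2):
        # Shapes must match before an elementwise comparison makes sense
        if p1.shape != p2.shape:
            return False
        # ne() is True wherever elements differ; any True means a mismatch
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True


torch.manual_seed(0)
small = nn.Linear(2, 2)
big = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
big[0].load_state_dict(small.state_dict())  # big's first layer is a copy of small

# Plain zip only sees the first two of big's four parameters
print(all(torch.equal(p1, p2) for p1, p2 in zip(small.parameters(), big.parameters())))  # True
print(parameters_equal(small, big))  # False
```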

Hey, just a small correction:
It should be:


I don’t think an abs is necessary (or wanted)… you probably want to distinguish between negative and positive values.

ne (not equal) returns Boolean values, which can simply be summed to count the differing elements.
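To see why summing ne() is enough, and why summing raw differences is not, consider two small tensors whose differences cancel out:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([2.0, 1.0, 3.0])

# Raw differences can cancel: (-1) + (+1) + 0 == 0, hiding two mismatches
diff_sum = (a - b).sum()
# ne() yields a Boolean tensor; summing it counts the differing elements
num_different = a.ne(b).sum()

print(diff_sum.item())       # 0.0
print(num_different.item())  # 2
```

Each element of ne() contributes 0 or 1, so nothing can cancel and no abs() is needed.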

I found this to be a better approach, since it also compares BatchNorm layers on their running_mean and running_var buffers:

import torch

def compare_models(model_1, model_2):
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if not torch.equal(key_item_1[1], key_item_2[1]):
            models_differ += 1
            if key_item_1[0] == key_item_2[0]:
                # Same entry name, different values
                print('Mismatch found at', key_item_1[0])
            else:
                # Entry names differ: the models have different architectures
                raise Exception('State dict keys do not match')
    if models_differ == 0:
        print('Models match perfectly! :)')
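Because that loop zips the two state dicts, it relies on both listing their entries in the same order. A minimal sketch that instead matches entries by key and returns the mismatching names rather than printing them is below; state_dict_mismatches is a hypothetical name. The usage part runs one forward pass in training mode, which updates the BatchNorm buffers but no parameters, to show that buffers are covered.

```python
import torch
import torch.nn as nn


def state_dict_mismatches(model_1: nn.Module, model_2: nn.Module) -> list:
    """Hypothetical helper: names of state_dict entries whose tensors differ."""
    sd_1 = model_1.state_dict()
    sd_2 = model_2.state_dict()
    # Matching by key (not by position) makes architecture differences explicit
    if sd_1.keys() != sd_2.keys():
        raise ValueError(f"State dict keys differ: {sorted(set(sd_1) ^ set(sd_2))}")
    # state_dict() includes buffers such as BatchNorm running_mean and running_var
    return [key for key in sd_1 if not torch.equal(sd_1[key], sd_2[key])]


torch.manual_seed(0)
model_1 = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
model_2 = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
model_2.load_state_dict(model_1.state_dict())

# A forward pass in training mode updates BatchNorm buffers, not parameters
model_2(torch.randn(4, 3))
print(state_dict_mismatches(model_1, model_2))
```

The printed list should contain only the BatchNorm buffer entries (running_mean, running_var and num_batches_tracked), since no optimizer step touched the weights.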

Why isn't this, or something like it, a built-in method?

Feature requests and PRs are always welcome if you think this feature is useful and commonly needed. ;)
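For what it's worth, newer PyTorch releases include torch.testing.assert_close, which accepts mappings of tensors, so two state dicts can be passed to it directly: it fails if the key sets differ or if any tensor differs. This is a sketch assuming a PyTorch version that provides assert_close; passing rtol=0 and atol=0 together requests exact equality instead of the default floating-point tolerances. The wrapper name state_dicts_identical is hypothetical.

```python
import torch
import torch.nn as nn


def state_dicts_identical(model_1: nn.Module, model_2: nn.Module) -> bool:
    """Hypothetical wrapper turning assert_close's AssertionError into a bool."""
    try:
        # rtol=0 and atol=0 must be given together; they demand exact equality
        torch.testing.assert_close(model_1.state_dict(), model_2.state_dict(), rtol=0, atol=0)
    except AssertionError:
        return False
    return True


torch.manual_seed(0)
model_1 = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
model_2 = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
model_2.load_state_dict(model_1.state_dict())
print(state_dicts_identical(model_1, model_2))  # True

with torch.no_grad():
    model_2[0].bias.add_(1e-4)  # a small change is still caught with zero tolerance
print(state_dicts_identical(model_1, model_2))  # False
```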


Adding an answer since I spent half an hour figuring out the same thing while checking whether a network was really updating (spoiler: it wasn't).

The easiest way I found to compare a network before and after training is to convert its state_dict to a string at each point and compare the two strings.

state_a = str(network.state_dict())
# ... run one or more training steps on network here ...
state_b = str(network.state_dict())
if state_a == state_b:
    print("Network not updating.")
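The string comparison has a catch: by default a printed tensor shows four decimal places and tensors with more than 1000 elements are abbreviated with "...", so a small update, or one in the hidden middle of a large weight, can leave the string unchanged. A minimal sketch that snapshots the state dict with clone() and compares tensors exactly is below; snapshot and changed_keys are hypothetical names, and the single SGD step on random data only stands in for a real training loop.

```python
import torch
import torch.nn as nn


def snapshot(model: nn.Module) -> dict:
    # clone() so later in-place optimizer updates cannot alter the snapshot
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


def changed_keys(before: dict, after: dict) -> list:
    # Exact elementwise comparison, unaffected by print precision or summarization
    return [k for k in before if not torch.equal(before[k], after[k])]


torch.manual_seed(0)
network = nn.Linear(10, 1)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

state_a = snapshot(network)
# One illustrative training step on random data
loss = network(torch.randn(8, 10)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
state_b = snapshot(network)

updated = changed_keys(state_a, state_b)
if not updated:
    print("Network not updating.")
else:
    print("Updated:", updated)
```

Listing the changed keys also shows which layers are (or are not) receiving updates, which helps when only part of a network is frozen by mistake.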