Check if models have same weights

Hello, I want to be able to check if two models have the same weights in their layers. After poking around, I couldn’t find a function that did this, so I implemented my own.

import torch
import torch.nn as nn

def compareModelWeights(model_a, model_b):
    module_a = model_a._modules
    module_b = model_b._modules
    # A different number of child modules means a different architecture
    if len(module_a) != len(module_b):
        return False
    for layer_name_a, layer_name_b in zip(module_a.keys(), module_b.keys()):
        # Child modules must have the same names, in the same order
        if layer_name_a != layer_name_b:
            return False
        layer_a = module_a[layer_name_a]
        layer_b = module_b[layer_name_b]
        # Recurse into every child module (containers and custom nn.Module
        # subclasses alike); a leaf layer has no children, so this is cheap
        if not compareModelWeights(layer_a, layer_b):
            return False
        # Compare this layer's own weight and bias (bias may be None)
        for attr in ('weight', 'bias'):
            tensor_a = getattr(layer_a, attr, None)
            tensor_b = getattr(layer_b, attr, None)
            if (tensor_a is None) != (tensor_b is None):
                return False
            if tensor_a is not None and not torch.equal(tensor_a.data, tensor_b.data):
                return False
    return True

It’s recursive because modules can have a tree structure.
If there is already a function that does this, I would love to use it in my code. If not, this might be a useful addition. Please let me know if there are any errors.
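Since nn.Module can already walk its own tree, the manual recursion can be replaced by `named_parameters()`, which recurses through all submodules and yields dotted names such as `0.weight`. A minimal sketch, assuming you only care about learnable parameters (buffers such as BatchNorm running statistics are not included); `same_weights` is a hypothetical name, not a PyTorch function:

```python
import torch
import torch.nn as nn


def same_weights(model_a: nn.Module, model_b: nn.Module) -> bool:
    """Return True if both models have identical parameter names and values."""
    # named_parameters() walks the whole module tree recursively and yields
    # dotted names such as "0.weight" or "2.bias".
    params_a = dict(model_a.named_parameters())
    params_b = dict(model_b.named_parameters())
    # Different parameter names means a different architecture.
    if params_a.keys() != params_b.keys():
        return False
    # torch.equal requires the same shape and exactly equal elements.
    return all(torch.equal(params_a[name], params_b[name]) for name in params_a)
```

A renamed or missing submodule changes the key set, so such models compare as different before any values are checked.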


You can iterate over model1.parameters() and model2.parameters() together and compare each pair:

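import torch

def compare_parameters(model1, model2):
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        # ne() is elementwise "not equal"; a nonzero count means the tensors differ
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True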
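One caveat with the zip() loop: zip() stops at the shorter of the two iterators, so if one model has extra parameters after the shared ones, they are never compared and the loop still reports a match. Also, ne() broadcasts, so tensors with different but broadcast-compatible shapes are not flagged. A sketch of a stricter variant using itertools.zip_longest and torch.equal (`params_equal` is a hypothetical name):

```python
import copy
import itertools

import torch
import torch.nn as nn


def params_equal(model1: nn.Module, model2: nn.Module) -> bool:
    # zip_longest pads the shorter iterator with None instead of stopping early.
    for p1, p2 in itertools.zip_longest(model1.parameters(), model2.parameters()):
        if p1 is None or p2 is None:
            return False  # different number of parameter tensors
        # Unlike ne(), torch.equal does not broadcast: a shape mismatch is False.
        if not torch.equal(p1, p2):
            return False
    return True


small = nn.Linear(3, 2)
bigger = nn.Sequential(copy.deepcopy(small), nn.Linear(2, 2))

# A plain zip() only sees the first two tensor pairs, which are identical copies.
zip_says_equal = all(torch.equal(p1, p2) for p1, p2 in zip(small.parameters(), bigger.parameters()))
strict_says_equal = params_equal(small, bigger)
```

Here `zip_says_equal` is True while `strict_says_equal` is False, because `bigger` has an extra Linear layer.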

Hey, just a small correction:
It should be:


I don’t think an abs is necessary (or wanted)… you probably want to distinguish between negative and positive values.

ne (not equal) returns Boolean values, which can simply be summed.
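A tiny illustration with made-up tensors: summing signed differences can cancel to zero even when the tensors differ, while summing the ne() mask counts the differing elements. Since the mask entries are only 0 or 1 when summed, abs() would not change anything there.

```python
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([2.0, 1.0])

# Signed differences cancel out: (-1.0) + (+1.0) == 0.0
diff_sum = (a - b).sum().item()      # 0.0, even though a != b

# ne() gives a Boolean mask; True counts as 1 when summed, so no abs() is needed
n_different = a.ne(b).sum().item()   # 2
```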

I found this to be a better approach, since it compares the whole state_dict, which also covers batch norm buffers such as running_mean and running_var:

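import torch

def compare_models(model_1, model_2):
    models_differ = 0
    # state_dict() holds parameters and buffers (e.g. BatchNorm running_mean/running_var)
    for (name_1, tensor_1), (name_2, tensor_2) in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        # Entries are matched by position, so their names must agree
        if name_1 != name_2:
            raise Exception('State dict keys do not match: {} vs {}'.format(name_1, name_2))
        if not torch.equal(tensor_1, tensor_2):
            models_differ += 1
            print('Mismatch found at', name_1)
    if models_differ == 0:
        print('Models match perfectly! :)')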
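To see why comparing the state_dict matters, here is a small check with a hypothetical helper, `mismatched_keys`: after a training-mode forward pass, a BatchNorm layer's running statistics change while its learnable parameters do not, so a parameters-only comparison misses the difference and a state_dict comparison catches it.

```python
import copy

import torch
import torch.nn as nn


def mismatched_keys(model_1: nn.Module, model_2: nn.Module, params_only: bool = False):
    """Names of entries whose tensors differ (hypothetical helper)."""
    if params_only:
        sd_1, sd_2 = dict(model_1.named_parameters()), dict(model_2.named_parameters())
    else:
        # state_dict() holds parameters *and* buffers such as BatchNorm running stats
        sd_1, sd_2 = model_1.state_dict(), model_2.state_dict()
    if sd_1.keys() != sd_2.keys():
        raise ValueError("models have different parameter/buffer names")
    return [k for k in sd_1 if not torch.equal(sd_1[k], sd_2[k])]


torch.manual_seed(0)
model_a = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model_b = copy.deepcopy(model_a)

# A forward pass in training mode updates model_b's BatchNorm running
# statistics but leaves every learnable parameter untouched.
model_b.train()
with torch.no_grad():
    model_b(torch.randn(8, 4))

params_diff = mismatched_keys(model_a, model_b, params_only=True)  # []
state_diff = mismatched_keys(model_a, model_b)  # the BatchNorm buffers
```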

Why isn’t this, or something like it, a built-in method?

Feature requests and PRs are always welcome if you think this feature is useful and commonly needed. :wink:
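I'm not aware of a dedicated nn.Module method for this, but in recent PyTorch versions `torch.testing.assert_close` accepts mappings of tensors, so it can compare two state dicts directly; passing `rtol=0, atol=0` demands exact equality, and by default it also checks that dtypes and devices match. It raises an `AssertionError` instead of returning a bool, so this sketch wraps it (`state_dicts_identical` is a hypothetical name):

```python
import copy

import torch
import torch.nn as nn


def state_dicts_identical(model_a: nn.Module, model_b: nn.Module) -> bool:
    try:
        # Mapping inputs are compared entry by entry; rtol=0, atol=0 turns the
        # closeness check into exact equality.
        torch.testing.assert_close(
            model_a.state_dict(), model_b.state_dict(), rtol=0, atol=0
        )
    except AssertionError as err:
        print(err)  # report from assert_close describing the mismatch
        return False
    return True
```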


Adding an answer since I spent half an hour trying to figure out the same thing while checking whether a network was really updating (spoiler: it wasn’t).

The easiest way I found to compare two networks is to convert each state_dict to its string representation and compare the strings.


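# network, loss and opt come from the training loop
state_a = network.state_dict().__str__()  # snapshot before the update
loss.backward()
opt.step()
state_b = network.state_dict().__str__()  # snapshot after the update

if state_a == state_b:
    print("Network not updating.")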
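Two caveats with the string approach. Printing a tensor shows only 4 digits of precision by default and, for tensors with more than 1000 elements, only the first and last few entries (the `precision` and `threshold` defaults of `torch.set_printoptions`), so a small update, or a change in the elided middle, can produce identical strings. Comparing cloned tensors avoids that. The clone matters: the tensors returned by `state_dict()` share storage with the live parameters, so an un-cloned "before" snapshot is updated in place by `opt.step()` and always matches the "after" one. A sketch with a hypothetical stand-in model and optimizer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
network = nn.Linear(4, 2)  # hypothetical stand-in model
opt = torch.optim.SGD(network.parameters(), lr=0.1)

# Pitfall: these tensors share storage with the parameters.
stale = network.state_dict()
# Proper snapshot: clone every tensor so later in-place updates don't touch it.
before = {name: t.clone() for name, t in network.state_dict().items()}

loss = network(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()

after = network.state_dict()
changed = [name for name in before if not torch.equal(before[name], after[name])]
stale_changed = [name for name in stale if not torch.equal(stale[name], after[name])]

if not changed:
    print("Network not updating.")
```

`stale_changed` stays empty no matter what the optimizer does, because `opt.step()` modified the "stale" tensors in place.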

Thanks for this!

PS: I think str(network.state_dict()) works too!