Changes to one nn model overwrite another model

I create two neural network models, m1 and m2, but when I modify the weights in m1, the weights in m2 change as well. I suspected I was not instantiating them properly, but they do have different memory addresses. I have also tried creating a custom class based on nn.Module instead of using nn.Sequential, but the issue remains. Any help is much appreciated.
The issue is demonstrated as follows:

from collections import OrderedDict
import torch
import torch.nn as nn
# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def init_weights(m):
    # Set every Linear layer's weights to 0.2 and its biases to 1.0
    if isinstance(m, nn.Linear):
        nn.init.constant_(m.weight, 0.2)
        nn.init.constant_(m.bias, 1.0)

D_in = 4
H = 8
D_out = 2
net_spec = OrderedDict([('0', torch.nn.Linear(D_in, H)),
                    ('1', torch.nn.ReLU()),
                    ('2', torch.nn.Linear(H, D_out))])

m1 = torch.nn.Sequential(net_spec)
m2 = torch.nn.Sequential(net_spec)
print('object addresses:')
print('   m1   :',hex(id(m1)))
print('   m2   :',hex(id(m2)))
print("\nm1 weights:\n",m1.state_dict()['0.weight'])
print("\nm2 weights:\n",m2.state_dict()['0.weight'])
m1.apply(init_weights)
print('m1 weights updated')
print("\nm1 weights:\n",m1.state_dict()['0.weight'])
print("\nm2 weights:\n",m2.state_dict()['0.weight'])

output:

object addresses:
   m1   : 0x7f893cbe91d0
   m2   : 0x7f893cbe92e8

m1 weights:
 tensor([[-0.4047,  0.1752, -0.1989, -0.2917],
        [-0.0466, -0.0044,  0.1561,  0.2465],
        [ 0.0700, -0.3357, -0.4978, -0.0837],
        [-0.0248, -0.2826, -0.4564, -0.0516],
        [-0.0024, -0.1732, -0.4144, -0.3790],
        [-0.3075,  0.2768, -0.2676,  0.4495],
        [-0.4929,  0.1328, -0.3153, -0.4591],
        [-0.0597,  0.3718,  0.0522,  0.0899]])

m2 weights:
 tensor([[-0.4047,  0.1752, -0.1989, -0.2917],
        [-0.0466, -0.0044,  0.1561,  0.2465],
        [ 0.0700, -0.3357, -0.4978, -0.0837],
        [-0.0248, -0.2826, -0.4564, -0.0516],
        [-0.0024, -0.1732, -0.4144, -0.3790],
        [-0.3075,  0.2768, -0.2676,  0.4495],
        [-0.4929,  0.1328, -0.3153, -0.4591],
        [-0.0597,  0.3718,  0.0522,  0.0899]])
m1 weights updated

m1 weights:
 tensor([[ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000]])

m2 weights:
 tensor([[ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000],
        [ 0.2000,  0.2000,  0.2000,  0.2000]])

Solution Found:
I found two ways to create a second model m2 that does not have this issue (both require import copy):

m2 = torch.nn.Sequential(copy.deepcopy(net_spec))

or

m2 = copy.deepcopy(m1)
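A minimal sketch that checks both workarounds, assuming the same layer sizes as in the question. It builds m1 from net_spec, m2 from a deep copy of net_spec, and m3 (an extra name used only for this illustration) as a deep copy of m1, then overwrites m1's first-layer weights in place and compares the other two models against a snapshot taken beforehand.

```python
import copy
from collections import OrderedDict

import torch
import torch.nn as nn

D_in, H, D_out = 4, 8, 2
net_spec = OrderedDict([('0', nn.Linear(D_in, H)),
                        ('1', nn.ReLU()),
                        ('2', nn.Linear(H, D_out))])

m1 = nn.Sequential(net_spec)                  # wraps the modules held in net_spec
m2 = nn.Sequential(copy.deepcopy(net_spec))   # workaround 1: copy the spec
m3 = copy.deepcopy(m1)                        # workaround 2: copy the finished model

# Snapshot the original first-layer weights, then overwrite m1's in place
before = m1[0].weight.detach().clone()
with torch.no_grad():
    m1[0].weight.fill_(0.2)

print(m1[0] is net_spec['0'])             # True: m1 still uses net_spec's layer
print(torch.equal(m2[0].weight, before))  # True: m2 kept the old weights
print(torch.equal(m3[0].weight, before))  # True: m3 kept the old weights
```

Both copies start with the same values as m1 but own separate parameter tensors, so later in-place changes to m1 do not reach them.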

It seems strange to me that the model's weights are tied to the particular OrderedDict instance passed to the constructor. If anyone can enlighten me, it would be appreciated.

You’re right, it is related to the OrderedDict you’re using to instantiate the nn.Sequential modules. Because PyTorch is imperative, whenever you write PyTorch code it is executed immediately, so when you do

net_spec = OrderedDict([('0', torch.nn.Linear(D_in, H)),
                    ('1', torch.nn.ReLU()),
                    ('2', torch.nn.Linear(H, D_out))])

you’re actually instantiating a Linear module with dimensions [D_in, H], a ReLU module, and another Linear module with dimensions [H, D_out]. These modules are stored in your OrderedDict, and when you pass it to nn.Sequential, the new Sequential module simply wraps those SAME, already-created modules. So the two Sequential modules are indeed different instances, but they both wrap the same underlying Linear modules:

 id(net_spec['0']) == id(m1[0]) == id(m2[0])  # Same nn.Linear instance
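Another way to avoid the sharing, sketched here as an alternative to deepcopy rather than as the fix proposed in this thread, is to construct fresh layers for every model, for example with a small factory function. The name make_net is hypothetical; init_weights mirrors the function from the question.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

D_in, H, D_out = 4, 8, 2

def make_net():
    # Hypothetical helper: every call creates a new OrderedDict holding
    # brand-new Linear/ReLU modules, so no layer is shared between models.
    return nn.Sequential(OrderedDict([
        ('0', nn.Linear(D_in, H)),
        ('1', nn.ReLU()),
        ('2', nn.Linear(H, D_out)),
    ]))

def init_weights(m):
    # Same initialisation as in the question
    if isinstance(m, nn.Linear):
        nn.init.constant_(m.weight, 0.2)
        nn.init.constant_(m.bias, 1.0)

m1 = make_net()
m2 = make_net()
print(m1[0] is m2[0])  # False: two distinct nn.Linear instances

before = m2[0].weight.detach().clone()
m1.apply(init_weights)
print(torch.equal(m2[0].weight, before))  # True: m2 is untouched
```

Unlike deepcopy, this gives each model its own random initialisation instead of a copy of m1's starting weights, which may or may not be what you want.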

Additionally, Python (just like PyTorch's C++ backend) uses something similar to call by reference: a function receives a reference to the object you pass, not a copy of it. Thus the Sequential class only receives a reference to the OrderedDict, which in turn only holds references to the layers.
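The following minimal sketch illustrates this argument-passing behaviour (usually called "call by sharing" in Python) with a plain OrderedDict; the helper names rebind and mutate are hypothetical. It also shows a detail worth knowing: nn.Sequential copies the (name, module) pairs into its own registry, so it is the module objects, not the OrderedDict itself, that end up shared.

```python
from collections import OrderedDict

import torch.nn as nn

def rebind(d):
    # Rebinding the local name does not affect the caller's object
    d = OrderedDict()
    return d

def mutate(d):
    # Mutating the shared object is visible to the caller
    d['extra'] = nn.Identity()

spec = OrderedDict([('0', nn.Linear(4, 8))])
rebind(spec)
print(list(spec))  # ['0']
mutate(spec)
print(list(spec))  # ['0', 'extra']

# nn.Sequential registers the module objects it is given, not copies of them
seq = nn.Sequential(spec)
print(seq[0] is spec['0'])  # True

# Later changes to the dict itself do not reach seq's own registry
spec['later'] = nn.ReLU()
print(len(seq))  # 2
```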
