load_state_dict() changes the outside variable!

I don’t understand why model.load_state_dict() changes the variable that was passed in to initialize the model's weight (a in the example below). This behavior is not mentioned anywhere in the documentation.

PyTorch version: 0.4.0

Here is the toy example:

import os

import torch
import torch.nn as nn

class MINIModel(nn.Module):
    def __init__(self, h0=None):
        super(MINIModel, self).__init__()
        self.h = nn.Linear(3, 4)
        if h0 is not None:
            self.h.weight.data = h0
    def forward(self, x):
        return self.h(x)

# Save a checkpoint of a randomly initialized model
os.makedirs('test', exist_ok=True)
m = MINIModel()
torch.save(m.state_dict(), 'test/params.ckpt')
print('saved model weight')
print(m.h.weight.data)

# Build a second model whose weight is initialized from a
a = torch.ones(4, 3)
m = MINIModel(a)
print('before loading, a is:')
print(a)
m.load_state_dict(torch.load('test/params.ckpt'))
print('after loading, a is:')
print(a)  # HERE: a should not get changed
print('after loading, weight is:')
print(m.h.weight.data)

which outputs:

saved model weight
tensor([[-0.5034,  0.0249, -0.4539],
        [-0.3079, -0.2559,  0.5453],
        [ 0.3883, -0.4435, -0.3697],
        [ 0.1465,  0.1131, -0.1957]])
before loading, a is:
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
after loading, a is:
tensor([[-0.5034,  0.0249, -0.4539],
        [-0.3079, -0.2559,  0.5453],
        [ 0.3883, -0.4435, -0.3697],
        [ 0.1465,  0.1131, -0.1957]])
after loading, weight is:
tensor([[-0.5034,  0.0249, -0.4539],
        [-0.3079, -0.2559,  0.5453],
        [ 0.3883, -0.4435, -0.3697],
        [ 0.1465,  0.1131, -0.1957]])

After loading, shouldn't a still be all ones? Can anyone help explain this?

Thank you very much.

This issue has nothing to do with load_state_dict() itself. The problem is that when you created your model with MINIModel(a), you made a itself the data of self.h.weight via the following lines:

    if h0 is not None:
        self.h.weight.data = h0

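Basically, self.h.weight and a now share the same underlying memory, so any in-place change to self.h.weight.data is reflected in a. load_state_dict() copies the saved values into the model's existing parameters in place, which is why a changes too. The same effect can be reproduced without load_state_dict() at all: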

import torch
import torch.nn as nn

class MINIModel(nn.Module):
    def __init__(self, h0=None):
        super(MINIModel, self).__init__()
        self.h = nn.Linear(3, 4)
        if h0 is not None:
            self.h.weight.data = h0
    def forward(self, x):
        return self.h(x)

a = torch.ones(4, 3)
m = MINIModel(a)
print('before loading, a is:')
print(a)
m.h.weight.data.add_(1)  # in-place change to the weight's data
print('after changing h.weight.data, a is:')
print(a)  # a is now all twos
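The sketch below, assuming a recent PyTorch release, checks the mechanism directly: it confirms that a and the weight point at the same memory, and that load_state_dict() keeps the existing Parameter object and copies the new values into it. The second model (fresh) is hypothetical and only supplies a state dict, so no checkpoint file is needed.

```python
import torch
import torch.nn as nn


class MINIModel(nn.Module):
    def __init__(self, h0=None):
        super(MINIModel, self).__init__()
        self.h = nn.Linear(3, 4)
        if h0 is not None:
            # Rebinds the weight's data to h0 itself; no copy is made.
            self.h.weight.data = h0

    def forward(self, x):
        return self.h(x)


a = torch.ones(4, 3)
m = MINIModel(a)

# Both names point at the same memory.
shares_memory = m.h.weight.data_ptr() == a.data_ptr()

# Hypothetical source model, used only to obtain a state dict with other values.
fresh = MINIModel()
state = fresh.state_dict()

weight_before = m.h.weight
m.load_state_dict(state)

# load_state_dict() keeps the same Parameter object and copies values into it in place...
same_parameter = m.h.weight is weight_before
# ...so a, which shares that memory, now holds the loaded values.
a_changed = torch.equal(a, state['h.weight'])

print(shares_memory, same_parameter, a_changed)
```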

To fix this, give the weight its own copy of h0 by changing:

    self.h.weight.data = h0

to:

    self.h.weight.data = h0.clone()
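A minimal check of the clone() fix, again assuming a recent PyTorch release and using an in-memory state dict in place of test/params.ckpt. The use_copy flag is a hypothetical parameter added only for illustration: it switches to an alternative that copies h0's values into the existing weight with copy_() under torch.no_grad(). With either variant the model no longer aliases a.

```python
import torch
import torch.nn as nn


class MINIModel(nn.Module):
    def __init__(self, h0=None, use_copy=False):  # use_copy is hypothetical
        super(MINIModel, self).__init__()
        self.h = nn.Linear(3, 4)
        if h0 is not None:
            if use_copy:
                # Alternative: copy h0's values into the weight's own storage.
                with torch.no_grad():
                    self.h.weight.copy_(h0)
            else:
                # The fix above: give the weight its own copy of h0.
                self.h.weight.data = h0.clone()

    def forward(self, x):
        return self.h(x)


# Stands in for torch.load('test/params.ckpt').
state = MINIModel().state_dict()

results = {}
for use_copy in (False, True):
    a = torch.ones(4, 3)
    m = MINIModel(a, use_copy=use_copy)
    m.load_state_dict(state)
    # a stays all ones, while the model holds the loaded weights.
    results[use_copy] = (
        torch.equal(a, torch.ones(4, 3)),
        torch.equal(m.h.weight, state['h.weight']),
    )

print(results)
```

clone() is the smallest change to the original code; the copy_() variant avoids touching .data, which is generally discouraged in current PyTorch.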
