I don't understand why model.load_state_dict() changes the variable that was passed in to initialize the model's weight (`a` in the example below). This behavior is not mentioned in the documentation.
PyTorch version: 0.4.0
Here is the toy example:
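```python
import os

import torch
import torch.nn as nn


class MINIModel(nn.Module):
    def __init__(self, h0=None):
        super(MINIModel, self).__init__()
        self.h = nn.Linear(3, 4)
        if h0 is not None:
            # use h0 as the initial weight by assigning it to .data (no copy is made)
            self.h.weight.data = h0

    def forward(self, x):
        return self.h(x)


os.makedirs('test', exist_ok=True)  # make sure the checkpoint directory exists

# save the randomly initialized weights of a first model
m = MINIModel()
torch.save(m.state_dict(), 'test/params.ckpt')
print('saved model weight')
print(m.h.weight.data)

# build a second model whose weight is initialized from a
a = torch.ones(4, 3)
m = MINIModel(a)
print('before loading, a is:')
print(a)

m.load_state_dict(torch.load('test/params.ckpt'))
print('after loading, a is:')
print(a)  # expected: a is unchanged (still all ones)
print('after loading, weight is:')
print(m.h.weight.data)
```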
which outputs:
```
saved model weight
tensor([[-0.5034,  0.0249, -0.4539],
        [-0.3079, -0.2559,  0.5453],
        [ 0.3883, -0.4435, -0.3697],
        [ 0.1465,  0.1131, -0.1957]])
before loading, a is:
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
after loading, a is:
tensor([[-0.5034,  0.0249, -0.4539],
        [-0.3079, -0.2559,  0.5453],
        [ 0.3883, -0.4435, -0.3697],
        [ 0.1465,  0.1131, -0.1957]])
after loading, weight is:
tensor([[-0.5034,  0.0249, -0.4539],
        [-0.3079, -0.2559,  0.5453],
        [ 0.3883, -0.4435, -0.3697],
        [ 0.1465,  0.1131, -0.1957]])
```
After loading, shouldn't `a` still be all ones? Can anyone explain this?
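A smaller check without a checkpoint file can help narrow this down. The sketch below assumes a recent PyTorch release and assumes that `load_state_dict` writes the loaded values into the existing parameter tensors in place rather than replacing them; the names `layer`, `layer2`, `b` and `new_weight` are hypothetical and only for illustration. It tests whether, after `weight.data = a`, the weight and `a` share the same memory, and whether initializing with `copy_` instead leaves the source tensor untouched.

```python
import torch
import torch.nn as nn

# Hypothetical minimal reproduction: the "saved" weights are an in-memory tensor.
a = torch.ones(4, 3)
layer = nn.Linear(3, 4)
layer.weight.data = a  # the parameter now uses a's storage; no copy is made

# Do the parameter and a point at the same memory?
shares_memory = layer.weight.data_ptr() == a.data_ptr()

# Assumption: load_state_dict copies values into the existing parameter in place,
# so the write should also be visible through a.
new_weight = torch.full((4, 3), 2.0)
layer.load_state_dict({"weight": new_weight, "bias": torch.zeros(4)})
a_changed = torch.equal(a, new_weight)

# Alternative initialization: copy the values so the parameter keeps its own storage.
b = torch.ones(4, 3)
layer2 = nn.Linear(3, 4)
with torch.no_grad():
    layer2.weight.copy_(b)  # copies values from b; no aliasing
layer2.load_state_dict({"weight": new_weight, "bias": torch.zeros(4)})
b_unchanged = torch.equal(b, torch.ones(4, 3))

print(shares_memory, a_changed, b_unchanged)  # True True True
```

If all three flags come out True, then `self.h.weight.data = h0` makes the weight an alias of `a` rather than a copy, and any in-place update of the weight, including the one made while loading the state dict, shows up in `a` as well. Initializing with `copy_` (or assigning `h0.clone()`) would keep the two separate.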
Thank you very much.