I tried to run the following code and the output confused me.
Code is:
class test_model(nn.Module):
    """Minimal module holding a single fully connected layer (10 -> 20)."""

    def __init__(self):
        super(test_model, self).__init__()
        # The model's only parameters: a 20x10 weight and a 20-element bias.
        self.fc = nn.Linear(10, 20)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        x = self.fc(x)
        return x
def model_sum(params):
    """Return the sum of all elements of all tensors in ``params``.

    Args:
        params: Mapping whose values support ``.sum().item()``, such as
            the tensors of a module's ``state_dict``.

    Returns:
        The accumulated total as a Python number (``0`` for an empty
        mapping).
    """
    total = 0
    for value in params.values():
        # .sum() reduces the tensor to a single element; .item() turns it
        # into a plain Python scalar so the running total is not a tensor.
        total += value.sum().item()
    return total
model = test_model()

# state_dict() returns references to the model's own parameter tensors
# (a shallow copy of the mapping, not copies of the data).
params = model.state_dict()
print(model_sum(params))

# `value + 1.0` allocates a brand-new tensor for every entry, so building
# new_params does not modify the model or `params`.
new_params = OrderedDict()
for key, value in params.items():
    new_params[key] = value + 1.0

# load_state_dict() copies the new values INTO the model's existing tensors.
# `params` still refers to those same tensors, so the second sum is larger
# by 220 (10 * 20 weights + 20 biases, each increased by 1.0).
# Use `copy.deepcopy(model.state_dict())` or clone each tensor to keep an
# independent snapshot instead.
model.load_state_dict(new_params)
print(model_sum(params))
And the output is
-0.15531682968139648
219.84470748901367
which suggests that `params` is not an independent copy: it still refers to the same tensors the model uses, so it changes when the model's state_dict is updated. Why does `state_dict()` behave this way?