Assigning a new tensor to a key in the state_dict won't change the values of the model's parameters, since you only replace the entry in the returned dict.
You could either load the manipulated state_dict afterwards or change the parameter's value in place, as shown here:
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
print(model.weight)
# Parameter containing:
# tensor([[0.8777]], requires_grad=True)  (random init, your value will differ)

sd = model.state_dict()
sd['weight'] = torch.tensor([[1.]])  # only replaces the dict entry
print(model.weight)
# Parameter containing:
# tensor([[0.8777]], requires_grad=True)  (unchanged)

model.load_state_dict(sd)  # copies the new values into the parameters
print(model.weight)
# Parameter containing:
# tensor([[1.]], requires_grad=True)
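Note that load_state_dict copies the values from the dict into the existing parameters instead of swapping in your tensor. The Parameter object stays the same, and later changes to the tensor you loaded from are not reflected in the model. A minimal check, assuming a fresh nn.Linear(1, 1) (the value 2. is arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
param_before = model.weight        # keep a reference to the Parameter object

sd = model.state_dict()
new_weight = torch.tensor([[1.]])
sd['weight'] = new_weight          # replaces the dict entry only
model.load_state_dict(sd)          # copies new_weight's values into model.weight

same_object = model.weight is param_before
print(same_object)                 # True: the Parameter was updated, not replaced
print(model.weight.item())         # 1.0

new_weight.fill_(2.)               # mutate the tensor we loaded from
print(model.weight.item())         # still 1.0, since the values were copied
```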
# or

model = nn.Linear(1, 1)
print(model.weight)
# Parameter containing:
# tensor([[-0.8112]], requires_grad=True)  (random init, your value will differ)

with torch.no_grad():
    sd = model.state_dict()
    sd['weight'].fill_(1.)  # in-place op writes directly into model.weight
print(model.weight)
# Parameter containing:
# tensor([[1.]], requires_grad=True)
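The in-place version works because state_dict() (with the default keep_vars=False) returns detached tensors that share their storage with the parameters. An in-place op on such an entry therefore writes through to the model, while rebinding the key to a new tensor breaks that link. A minimal check, assuming a fresh nn.Linear(1, 1) (the values 3. and 5. are arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
sd = model.state_dict()  # default keep_vars=False: entries are detached tensors

# The detached entry points to the same memory as the parameter ...
shares_storage = sd['weight'].data_ptr() == model.weight.data_ptr()
print(shares_storage)                 # True
print(sd['weight'].requires_grad)     # False (detached)

# ... so an in-place op writes through to model.weight.
with torch.no_grad():
    sd['weight'].fill_(3.)
print(model.weight.item())            # 3.0

# Rebinding the key creates a new tensor and breaks that link.
sd['weight'] = torch.tensor([[5.]])
print(model.weight.item())            # still 3.0
```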