If you don’t create a deepcopy of the state_dict, it will work, since state_dict() returns tensors that share their storage with the model’s parameters, so in-place changes are visible in the model:
from torchvision import models

model = models.resnet18()
sd = model.state_dict()
sd['fc.weight'].zero_()  # in-place op also zeroes the model's parameter
print(model.fc.weight)
> Parameter containing:
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], requires_grad=True)
However, the better approach would probably be to zero out the parameters directly in the model by wrapping the operation in a with torch.no_grad() block and manipulating the parameters as you wish.