Modifying a model's layer weights via state_dict

If you don’t create a deepcopy of the state_dict, it should work: state_dict() returns (detached) tensors that still share their storage with the model’s parameters, so in-place modifications are reflected in the model directly:

from torchvision import models

model = models.resnet18()
sd = model.state_dict()   # tensors in sd share storage with the model's parameters
sd['fc.weight'].zero_()   # in-place zeroing therefore also zeroes model.fc.weight
print(model.fc.weight)
> Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

However, the better approach would probably be to zero out the parameters directly in the model by wrapping the manipulation in a with torch.no_grad() block and changing the parameters as you wish, since in-place operations on leaf tensors that require gradients are otherwise disallowed by autograd.
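A minimal sketch of that approach, reusing the same resnet18 model as above:

import torch
from torchvision import models

model = models.resnet18()

# Inside no_grad, in-place ops on parameters are permitted and not tracked by autograd
with torch.no_grad():
    model.fc.weight.zero_()

print(model.fc.weight)  # the zeroed parameter, still with requires_grad=True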
