I want to reset the gradients of certain weights. Setting requires_grad = False works at the layer level, but what I want is to set requires_grad to False for certain individual weights within a layer. Currently I am doing it in a very crude way, by manually setting their gradients to 0 after calling .backward(), as below:
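(A simplified sketch; net and mask are placeholder names for my actual model and for the weights I want to freeze.)

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # placeholder model
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# boolean mask marking the individual weights I want to freeze
mask = torch.zeros_like(net.weight, dtype=torch.bool)
mask[0, :] = True

loss = net(torch.randn(8, 4)).sum()
loss.backward()

# the crude part: manually zero the gradients of the masked weights
net.weight.grad[mask] = 0.0

opt.step()
```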
No, you cannot set requires_grad = False just for certain elements of a pytorch tensor; requires_grad applies to the entire tensor as a whole.
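As a quick illustration:

```python
import torch

t = torch.randn(3, requires_grad=True)
print(t.requires_grad)     # True -- the flag belongs to the whole tensor
print(t[0].requires_grad)  # also True: t[0] is a view that inherits the flag;
                           # there is no per-element requires_grad
```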
How best to do this depends on your actual goal.
If what you want is to “freeze” certain elements of your weight tensor – that is, have those elements not change when you call opt.step() – then it is safer to save the old values, call opt.step() (which might modify them), and then restore the saved values (rather than setting their gradients to zero before calling opt.step()).
This is because if you are using weight decay (or potentially momentum),
opt.step() can modify elements for which the gradient is zero.
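Here is a minimal sketch of that save-and-restore approach (the weight tensor, the frozen_mask marking the elements to freeze, and the optimizer settings are illustrative assumptions, not code from your post):

```python
import torch

# hypothetical setup: a single weight tensor and a boolean mask
# marking the elements we want to freeze
weight = torch.nn.Parameter(torch.randn(4, 4))
frozen_mask = torch.zeros(4, 4, dtype=torch.bool)
frozen_mask[0, :] = True  # e.g., freeze the first row

# weight decay is exactly the case where zeroing gradients is not enough
opt = torch.optim.SGD([weight], lr=0.1, weight_decay=0.01)

loss = (weight ** 2).sum()
loss.backward()

saved = weight.data[frozen_mask].clone()  # save the old values
opt.step()                                # may modify even zero-grad elements
with torch.no_grad():
    weight[frozen_mask] = saved           # restore the frozen elements
```

With weight_decay = 0.01, opt.step() adds a decay term to every element’s update regardless of its gradient, so it is the restore step that actually keeps the frozen elements unchanged.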