I have a model and I want to disable gradients for some of its individual weights, so I flatten the weights of the model in order to iterate through them and turn off requires_grad for a subset:
params = torch.cat([param.flatten() for param in model.parameters()])
for i, param in enumerate(params):
    if should_be_disabled[i]:
        param.requires_grad_(False)
This raises a RuntimeError saying that requires_grad flags can only be changed for leaf variables.
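Your first line concatenates all the parameters into a single big Tensor in a differentiable manner. That makes params (and every element you get by iterating over it) a non-leaf result of an autograd operation, and its requires_grad flag is independent of the flags on your parameters: even if you could change it, the parameters would still receive gradients.
You can, however, disable gradients for whole Tensors in model.parameters(), since those are leaves (note that here should_be_disabled is indexed per parameter Tensor, not per entry):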
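A quick way to see the leaf/non-leaf distinction is the sketch below. It assumes a small nn.Linear as a stand-in for "the model"; the flattened Tensor is not a leaf, the parameters are, and trying to flip requires_grad on an element of params reproduces the error.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # small stand-in for the real model (assumption)
params = torch.cat([param.flatten() for param in model.parameters()])

print(params.is_leaf)        # False: params is the output of torch.cat
print(params.requires_grad)  # True: inherited from the parameters
print(all(p.is_leaf for p in model.parameters()))  # True: parameters are leaves

raised = False
try:
    # params[0] is a view of a non-leaf Tensor, so this is rejected
    params[0].requires_grad_(False)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```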
for j, p in enumerate(model.parameters()):
    if should_be_disabled[j]:
        p.requires_grad_(False)
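Here is a minimal end-to-end sketch of the per-Tensor approach, assuming an nn.Linear model and a hypothetical should_be_disabled list with one flag per parameter Tensor (weight, bias). After backward, the frozen Tensor has no gradient at all while the other one is trained normally.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)
# Hypothetical flags: one per parameter Tensor (weight, bias), not per entry
should_be_disabled = [True, False]

for j, p in enumerate(model.parameters()):
    if should_be_disabled[j]:
        p.requires_grad_(False)

# Optionally hand the optimizer only the Tensors that still train
trainable = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.SGD(trainable, lr=0.1)

loss = model(torch.randn(4, 3)).sum()
loss.backward()
print(model.weight.grad)  # None: weight was frozen
print(model.bias.grad)    # tensor([4., 4.]): 4 samples, each contributes 1
```

Filtering the optimizer's parameter list is not strictly required (SGD skips parameters whose grad is None), but it makes the intent explicit.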
Disabling gradients for individual entries of a Tensor is, I am afraid, not possible. Tensors are the “elementary” objects of autograd: requires_grad is a property of the whole Tensor, so either every entry requires gradients or none does.
Note that if you only want some entries to get no gradient, you can simply zero out those entries after the gradients are computed. You can even do that with a hook (Tensor.register_hook) to make sure it happens every time a gradient is computed for that Tensor.
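The "zero out after computing" option can be sketched as follows, assuming hypothetical boolean masks shaped like each parameter (True means "freeze this entry"). The masking happens between backward() and step(), so plain SGD leaves the frozen entries untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
# Hypothetical per-entry masks, one bool Tensor per parameter; True = frozen
should_be_disabled = [torch.zeros_like(p, dtype=torch.bool) for p in model.parameters()]
should_be_disabled[0][0, :2] = True  # freeze the first two weights of output 0

opt = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()

loss = model(torch.randn(8, 4)).pow(2).sum()
opt.zero_grad()
loss.backward()
with torch.no_grad():
    for p, mask in zip(model.parameters(), should_be_disabled):
        p.grad.masked_fill_(mask, 0.0)  # zero the gradient of frozen entries
opt.step()

frozen_unchanged = torch.equal(model.weight.detach()[0, :2], before[0, :2])
print(frozen_unchanged)  # True
```

One caveat: this only keeps the entries fixed if the optimizer does not add its own term to the gradient. For example, SGD with a non-zero weight_decay adds weight_decay * p inside step(), so masked entries would still move.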
That would still be good. Thanks! Could you provide an example of this please?
My only concern is that I do three separate backward passes for three separate loss terms, and I’m worried it’ll get convoluted because each one requires different gradients to be zeroed out.
import torch

def get_hook(param_idx):
    def hook(grad):
        grad = grad.clone()  # NEVER change the given grad inplace
        # Assumes 1D but can be generalized
        # should_be_disabled[param_idx][i] flags entry i of parameter param_idx
        for i in range(grad.size(0)):
            if should_be_disabled[param_idx][i]:
                grad[i] = 0
        return grad
    return hook

for j, p in enumerate(model.parameters()):
    p.register_hook(get_hook(j))
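For the three-loss case, one option is to register the hooks once and let them look up which loss is currently being backpropagated. The sketch below assumes hypothetical names (masks, state, compute_losses, loss_a/b/c) and placeholder loss terms. Because a hook on a leaf runs on each backward pass before the gradient is accumulated into .grad, every loss's contribution is masked with its own mask. It also replaces the per-entry loop with masked_fill, which works for any parameter shape.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.randn(8, 4)

names = ["loss_a", "loss_b", "loss_c"]
# Hypothetical masks: masks[name][j] is a bool Tensor shaped like the j-th
# parameter; True marks entries whose gradient is zeroed for that loss only.
masks = {n: [torch.zeros_like(p, dtype=torch.bool) for p in model.parameters()]
         for n in names}
masks["loss_a"][0][0] = True  # loss_a must not update row 0 of the weight
masks["loss_b"][0][1] = True  # loss_b must not update row 1 of the weight
masks["loss_c"][1][:] = True  # loss_c must not update the bias

state = {"active": None}  # name of the loss currently being backpropagated

def get_hook(param_idx):
    def hook(grad):
        mask = masks[state["active"]][param_idx]
        # masked_fill (not masked_fill_) returns a new Tensor, so the
        # incoming grad is never modified in-place
        return grad.masked_fill(mask, 0.0)
    return hook

def compute_losses(out):
    # Placeholder loss terms standing in for the three real ones
    return {"loss_a": out.pow(2).sum(), "loss_b": out.abs().sum(), "loss_c": out.sum()}

handles = [p.register_hook(get_hook(j)) for j, p in enumerate(model.parameters())]

losses = compute_losses(model(x))
model.zero_grad()
for i, n in enumerate(names):
    state["active"] = n
    # The losses share one graph, so keep it alive until the last pass
    losses[n].backward(retain_graph=i < len(names) - 1)

for h in handles:
    h.remove()  # stop masking once the masked passes are done
```

Keeping the masks in one dictionary keyed by loss name means the backward loop itself stays the same regardless of which entries each loss is allowed to touch.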