How to turn off requires_grad for individual params?

I know it’s trivial for a parameter vector, but when I iterate through the params in, for example:

torch.cat([param.flatten() for param in self.policy.model.parameters()])

setting requires_grad to False raises an error. How do I turn off the gradient for individual scalar weights within a parameter?

Hi,

Could you share a code sample of what you’re trying to do exactly and what is the exact error please?

I have a model and I want to disable some of the weights, so I flatten the weights of the model in order to iterate through them and turn off requires_grad for a subset:

params = torch.cat([param.flatten() for param in model.parameters()])
for i, param in enumerate(params):
    if should_be_disabled[i]:
        param.requires_grad_(False)

This raises an error saying that requires_grad can only be changed for leaf Tensors.

Hi,

Your first line actually concatenates all the parameters into a single big Tensor in a differentiable manner.
So the requires_grad property of the new Tensor params is independent of the ones on your original parameters.
You can disable gradients for the Tensors in model.parameters() though:

for j, p in enumerate(model.parameters()):
    if should_be_disabled[j]:
        p.requires_grad_(False)
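
For illustration, here is a minimal sketch (the Linear layer is just a made-up example model) of why the original snippet fails: the concatenated Tensor is not a leaf, so its requires_grad flag cannot be changed, while the actual parameters are leaves and can be toggled freely:

import torch

model = torch.nn.Linear(3, 2)  # hypothetical example model

params = torch.cat([p.flatten() for p in model.parameters()])
print(params.is_leaf)  # False: the result of cat is a non-leaf Tensor
# params.requires_grad_(False)  # raises: requires_grad can only be changed on leaf Tensors

for p in model.parameters():  # the real parameters are leaves
    p.requires_grad_(False)
print(any(p.requires_grad for p in model.parameters()))  # False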

The problem is that p is a whole vector in your example, and I would like to disable only a subset of the individual scalar weights within that vector.

I am afraid that this is not possible. Tensors are “elementary” autograd objects, so either the whole Tensor requires gradients or it doesn’t.

Note that you can simply zero out the gradients after they are computed if you just want some entries to have no gradient (you can even do that with a hook to make sure it happens every time a gradient is computed for that Tensor).

That would still be good. Thanks! Could you provide an example of this please?

My only concern is that I do three separate backward passes for three separate loss terms, and I’m worried it’ll get convoluted because each one requires different gradients to be zeroed out.

Sure

def get_hook(param_idx):
    def hook(grad):
        grad = grad.clone()  # NEVER change the given grad in-place
        # Assumes the parameter is 1D but can be generalized
        for i in range(grad.size(0)):
            if should_be_disabled[param_idx][i]:
                grad[i] = 0
        return grad
    return hook

for j, p in enumerate(model.parameters()):
    p.register_hook(get_hook(j))
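
As a side note (this is just a sketch, assuming should_be_disabled[j] can be stored as a boolean mask with the same shape as the j-th parameter), the inner loop can be replaced by a vectorized multiply, which also works for parameters that are not 1D:

import torch

# hypothetical: one boolean mask per parameter, True where the gradient should be zeroed
masks = [torch.as_tensor(should_be_disabled[j]).reshape(p.shape)
         for j, p in enumerate(model.parameters())]

def get_hook(param_idx):
    def hook(grad):
        # returns a new Tensor, so the given grad is not modified in-place
        return grad * (~masks[param_idx]).to(grad.dtype)
    return hook

for j, p in enumerate(model.parameters()):
    p.register_hook(get_hook(j))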
    