Getting "RuntimeError: One of the differentiated Tensors does not require grad" in PyTorch Lightning

This is the key part of the configure_optimizers() function in PyTorch Lightning:

params = []
train_names = []
print("Training only unet attention layers")
for name, module in self.model.diffusion_model.named_modules():
    if isinstance(module, CrossAttention) and name.endswith('attn2'):
        # Collect only the cross-attention (attn2) parameters for training
        train_names.append(name)
        params.extend(module.parameters())
    else:
        # Set requires_grad=False for all other parameters
        for param in module.parameters():
            param.requires_grad = False
opt = torch.optim.AdamW(params, lr=lr)

Guessing without a repro, but can you try explicitly setting param.requires_grad = True in your if branch?
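Something along these lines (just a sketch against your snippet; the extra check in the else branch matters because named_modules() also yields the submodules inside each attn2 block, and an unguarded else would re-freeze the parameters that were just enabled):

params = []
train_names = []
for name, module in self.model.diffusion_model.named_modules():
    if isinstance(module, CrossAttention) and name.endswith('attn2'):
        train_names.append(name)
        for param in module.parameters():
            param.requires_grad = True   # explicitly mark as trainable
        params.extend(module.parameters())
    elif 'attn2' not in name:
        # Freeze everything else; skip submodules under attn2 so they
        # are not re-frozen after being enabled above
        for param in module.parameters():
            param.requires_grad = False
opt = torch.optim.AdamW(params, lr=lr)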

Hi Marksaroufim,

Thank you for replying.

What I want to do here is set requires_grad = True only for the attention parameters and pass them to the optimizer.
It does work if I set requires_grad = True on all the parameters.

I just found out that this error seems to be caused by the custom-defined class CheckpointFunction(torch.autograd.Function),
and I made the code work by simply avoiding that class.
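In codebases where CheckpointFunction is used through a small wrapper like the one sketched below (the names here are illustrative, not necessarily the actual API), avoiding the class just means calling the wrapped function directly, e.g. by passing flag=False or turning off a use_checkpoint option if the codebase has one:

def checkpoint(func, inputs, params, flag):
    # flag=False skips the custom autograd.Function entirely, so ordinary
    # autograd handles the frozen parameters without complaint
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    return func(*inputs)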

Hi Yuanzhi, I ran into the same problem when I made only the attention layers trainable. Could you clarify how to change CheckpointFunction(torch.autograd.Function)?

I also encountered the same situation when debugging the code. Did you solve this problem? I look forward to your reply.

I'm facing the same issue here. The reason CheckpointFunction does not work is that it tries to compute gradients with respect to parameters that have requires_grad=False.
For those who want to keep using CheckpointFunction in this case, you can modify the backward function like this:

class CheckpointFunction(torch.autograd.Function):
    .......
    @staticmethod
    def backward(ctx, *output_grads):
        ......

        # Only differentiate w.r.t. tensors that actually require grad
        grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad],
            output_grads,
            allow_unused=True,
        )

        grads = list(grads)

        # Re-align gradients with the original inputs, inserting None for tensors that do not require gradients
        input_grads = []
        for tensor in ctx.input_tensors + ctx.input_params:
            if tensor.requires_grad:
                input_grads.append(grads.pop(0))  # Get the next computed gradient
            else:
                input_grads.append(None)  # No gradient required for this tensor

        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + tuple(input_grads)
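For context, the underlying error is easy to reproduce in isolation: torch.autograd.grad raises it whenever one of the tensors it is asked to differentiate with respect to has requires_grad=False, which is exactly what happens when frozen parameters end up in ctx.input_params. A minimal standalone sketch:

import torch

x = torch.randn(3, requires_grad=True)
frozen = torch.randn(3)            # requires_grad=False, like a frozen parameter
y = (2 * x + frozen).sum()

# This raises: RuntimeError: One of the differentiated Tensors does not require grad
# torch.autograd.grad(y, [x, frozen], allow_unused=True)

# Filtering out tensors that do not require grad, as in the patched backward above, works:
grads = torch.autograd.grad(y, [t for t in (x, frozen) if t.requires_grad])
print(grads)  # gradient w.r.t. x only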