Getting "RuntimeError: One of the differentiated Tensors does not require grad" in PyTorch Lightning

This is the key part of the configure_optimizers() function in PyTorch Lightning:

params = []
train_names = []
print("Training only unet attention layers")
for name, module in self.model.diffusion_model.named_modules():
    if isinstance(module, CrossAttention) and name.endswith('attn2'):
        # Collect the cross-attention (attn2) parameters for training
        train_names.append(name)
        params.extend(module.parameters())
    else:
        # Set requires_grad=False for all other parameters
        for param in module.parameters():
            param.requires_grad = False
opt = torch.optim.AdamW(params, lr=lr)

I'm guessing without a repro, but can you try explicitly setting param.requires_grad = True in your if condition?
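A minimal sketch of that suggestion, with a toy module standing in for the diffusion UNet (plain nn.Linear layers replace CrossAttention here, purely for illustration). Freezing everything first and then explicitly re-enabling the target layers avoids the ordering pitfall where a parent module's freeze pass overwrites a child's setting:

```python
import torch
import torch.nn as nn

# Toy stand-in for the UNet: only the module named 'attn2' should be trained
model = nn.ModuleDict({
    "attn1": nn.Linear(4, 4),
    "attn2": nn.Linear(4, 4),
})

# Freeze everything first...
for p in model.parameters():
    p.requires_grad = False

# ...then explicitly re-enable and collect the target layers
params = []
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and name.endswith("attn2"):
        for p in module.parameters():
            p.requires_grad = True  # explicit, as suggested above
            params.append(p)

opt = torch.optim.AdamW(params, lr=1e-4)
```

After this, `params` holds exactly the attn2 weight and bias, and the attn1 parameters stay frozen.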

Hi Marksaroufim,

Thank you for replying.

What I want to do here is set requires_grad = True only on the attention parameters and pass them to the optimizer.
It did work when I set requires_grad = True on all the parameters.

I just found out that this error seems to be caused by the custom-defined class CheckpointFunction(torch.autograd.Function),
and I made the code work simply by avoiding that class.
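For anyone who wants to drop the custom class without losing activation checkpointing: PyTorch's built-in non-reentrant checkpointing explicitly supports inputs and parameters that don't require grad. A minimal sketch with a frozen and a trainable layer standing in for the real network:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# One frozen layer and one trainable layer (stand-ins for the real model)
frozen = nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad_(False)
trainable = nn.Linear(4, 4)

x = torch.randn(2, 4)

# use_reentrant=False tolerates tensors/parameters with requires_grad=False,
# which is exactly the case that breaks the custom CheckpointFunction
y = checkpoint(lambda t: trainable(frozen(t)), x, use_reentrant=False)
y.sum().backward()
```

Here the backward pass populates gradients for the trainable layer while leaving the frozen one untouched, with no RuntimeError.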

Hi Yuanzhi, I ran into the same problem when I made only the attention layers trainable. Could you clarify how you changed CheckpointFunction(torch.autograd.Function)?

I also ran into the same situation while debugging the code. Did you solve this problem? I look forward to your reply.

I’m facing the same issue here. The reason CheckpointFunction does not work is that it tries to compute gradients for parameters that have requires_grad=False.
For those who want to keep using CheckpointFunction in this case, you can modify its backward function like this:

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *output_grads):
        # Recompute the forward pass with grad enabled (standard checkpointing)
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)

        # Only ask autograd for gradients of parameters that actually require them
        grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad],
            output_grads,
            allow_unused=True,
        )
        grads = list(grads)

        # Assign gradients to the correct positions, inserting None for
        # tensors that do not require gradients
        input_grads = []
        for tensor in ctx.input_tensors + ctx.input_params:
            if tensor.requires_grad:
                input_grads.append(grads.pop(0))  # next computed gradient
            else:
                input_grads.append(None)  # no gradient required for this tensor

        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + tuple(input_grads)
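To sanity-check the patched backward end to end, here is a self-contained toy version (the forward follows the usual pattern from the latent-diffusion codebase; the two linear layers are stand-ins for the real network). With one layer frozen, backward now completes instead of raising the RuntimeError:

```python
import torch
import torch.nn as nn

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Recompute the forward pass with grad enabled
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        # Patched: only differentiate w.r.t. params that require grad
        grads = list(torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + [p for p in ctx.input_params if p.requires_grad],
            output_grads,
            allow_unused=True,
        ))
        # Re-align: None for tensors that do not require gradients
        input_grads = []
        for t in ctx.input_tensors + ctx.input_params:
            input_grads.append(grads.pop(0) if t.requires_grad else None)
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + tuple(input_grads)

# Demo: one frozen and one trainable layer inside the checkpointed function
frozen = nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad_(False)
trainable = nn.Linear(4, 4)

def run(x):
    return trainable(frozen(x))

x = torch.randn(2, 4, requires_grad=True)
params = list(frozen.parameters()) + list(trainable.parameters())
y = CheckpointFunction.apply(run, 1, x, *params)
y.sum().backward()
```

After backward, the trainable layer's gradients are populated while the frozen layer's stay None, which is exactly the behavior the unpatched version failed to provide.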