This is the key part of the configure_optimizers() method in PyTorch Lightning:
# Requires torch and the CrossAttention class (ldm.modules.attention in the
# CompVis codebase) to be imported at module level.
params = []
train_names = []
print("Training only unet attention layers")

# Freeze everything first. Freezing inside the named_modules() loop would also
# freeze the attn2 weights, because parent modules (and the attn2 sub-layers)
# appear as separate entries whose .parameters() include the attn2 tensors.
for param in self.model.diffusion_model.parameters():
    param.requires_grad = False

# Re-enable and collect only the cross-attention (attn2) blocks.
for name, module in self.model.diffusion_model.named_modules():
    if isinstance(module, CrossAttention) and name.endswith('attn2'):
        train_names.append(name)
        for param in module.parameters():
            param.requires_grad = True
        params.extend(module.parameters())

opt = torch.optim.AdamW(params, lr=lr)  # lr is set earlier in configure_optimizers (typically self.learning_rate)
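As a quick sanity check (not part of the original snippet; pl_module below is a placeholder name for the LightningModule instance), you can confirm that only the attn2 weights remain trainable after configure_optimizers() has run:

trainable = [n for n, p in pl_module.model.diffusion_model.named_parameters() if p.requires_grad]
print(len(trainable), "trainable tensors")
print(trainable[:3])  # every name should contain 'attn2'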