Optimizer step only for a specific group of parameters

I have an Adam optimizer initialized with a list of parameter-group dicts, where I set a different lr for two separate groups of parameters.

During the forward pass I compute an L1 loss and a distillation loss, and I would like to optimize only group 0 with the L1 loss gradients and only group 1 with the distillation loss gradients.

Can I do something like this?

opt.zero_grad()
opt.param_groups[0].requires_grad_(True)
opt.param_groups[1].requires_grad_(False)
self.manual_backward(step_out["loss"])
opt.step()

distillation_loss = self.flow_distillation(step_out["flow_f"], step_out["hr"])

if self.reverse:
    distillation_loss += self.flow_distillation(step_out["flow_b"], step_out["hr"], reverse=True)

step_out["distillation_loss"] = distillation_loss
step_out["loss"] += distillation_loss

opt.zero_grad()
opt.param_groups[1].requires_grad_(True)
opt.param_groups[0].requires_grad_(False)
self.manual_backward(step_out["distillation_loss"])
opt.step()

No, the entries in param_groups are plain dicts and don't have a .requires_grad_() method; they hold references to the parameters, the learning rate, etc.
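To make this concrete, here is a small sketch (the nn.Linear model and the learning rates are made up for illustration): each param_groups entry is a plain dict, while the tensors stored under its "params" key are the model's own parameters, so any requires_grad toggling has to happen on those tensors.

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)
opt = torch.optim.Adam([
    {"params": [model.weight], "lr": 1e-3},
    {"params": [model.bias], "lr": 1e-4},
])

group = opt.param_groups[0]
print(type(group))                       # <class 'dict'>
print(group["lr"])                       # 0.001
print(hasattr(group, "requires_grad_"))  # False

# The tensors under "params" are the model's parameters themselves,
# so they can be frozen directly.
for p in group["params"]:
    p.requires_grad_(False)
print(model.weight.requires_grad)        # False
```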
While you could set the .requires_grad attribute of the actual parameters to True/False, I would probably just use two separate optimizers as it seems to be the cleaner approach.
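A minimal sketch of the two-optimizer setup in plain PyTorch (no DeepSpeed), reusing a small two-layer model. The placeholder losses stand in for the L1 and distillation losses. Passing inputs= to backward() is my own choice here, not something from this thread: it restricts gradient accumulation to the listed tensors, so fc1 only receives loss1's gradients and fc2 only loss2's, and both steps can run after both backward calls without modifying any parameter in between.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

model = MyModel()
# One optimizer per group instead of one optimizer with two param_groups
opt_fc1 = torch.optim.Adam(model.fc1.parameters(), lr=1.0)
opt_fc2 = torch.optim.Adam(model.fc2.parameters(), lr=1e-1)

x = torch.randn(8, 1)
out = model(x)
loss1 = out.abs().mean()   # placeholder for the L1 loss
loss2 = out.pow(2).mean()  # placeholder for the distillation loss

opt_fc1.zero_grad()
opt_fc2.zero_grad()
# inputs= limits where gradients are accumulated
loss1.backward(inputs=list(model.fc1.parameters()), retain_graph=True)
loss2.backward(inputs=list(model.fc2.parameters()))
opt_fc1.step()
opt_fc2.step()
```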

Yes, the problem is that I am using the DeepSpeed strategy, which doesn't support more than one optimizer; otherwise I would do exactly what you suggested.
If I go with the first approach (toggling .requires_grad on the parameters), is this order of operations correct?

# train group 0, freeze group 1
for p in opt.param_groups[0]["params"]: p.requires_grad_(True)
for p in opt.param_groups[1]["params"]: p.requires_grad_(False)

opt.zero_grad()
self.manual_backward(step_out["loss"])
opt.step()

...

# freeze group 0, train group 1
for p in opt.param_groups[0]["params"]: p.requires_grad_(False)
for p in opt.param_groups[1]["params"]: p.requires_grad_(True)

opt.zero_grad()
self.manual_backward(step_out["distillation_loss"])
opt.step()

Or should I zero the gradients only one time? Thank you for your time @ptrblck

Ah OK, that’s an unfortunate limitation.
Your approach might work as seen in this small example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)
        
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
    
model = MyModel()
optimizer = torch.optim.Adam([
    {"params": model.fc1.parameters(), "lr": 1.},
    {"params": model.fc2.parameters(), "lr": 1e-1}
])

x = torch.randn(1, 1)
out = model(x)
loss1 = out.mean()
loss2 = out.sum()

for param in optimizer.param_groups[1]["params"]:
    param.requires_grad = False
    
loss1.backward(retain_graph=True)
print([(name, p.grad) for name, p in model.named_parameters()])
# [('fc1.weight', tensor([[-0.3356]])), ('fc1.bias', tensor([-0.8679])), ('fc2.weight', None), ('fc2.bias', None)]
optimizer.step()

optimizer.zero_grad()
for param in optimizer.param_groups[1]["params"]:
    param.requires_grad = True
for param in optimizer.param_groups[0]["params"]:
    param.requires_grad = False
loss2.backward()
print([(name, p.grad) for name, p in model.named_parameters()])
# [('fc1.weight', None), ('fc1.bias', None), ('fc2.weight', tensor([[0.1324]])), ('fc2.bias', tensor([1.]))]
optimizer.step()
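The toggling above can be wrapped in a small helper. This is a sketch: train_only_group is a hypothetical name, and unlike the question (one forward pass, two backward calls), each step here runs its own forward pass, so the frozen group is never part of that step's graph.

```python
import torch
import torch.nn as nn

def train_only_group(optimizer, group_idx):
    # Hypothetical helper: parameters in param_groups[group_idx] get
    # requires_grad=True, parameters in all other groups are frozen.
    for i, group in enumerate(optimizer.param_groups):
        for p in group["params"]:
            p.requires_grad_(i == group_idx)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(), nn.Linear(1, 1))
optimizer = torch.optim.Adam([
    {"params": model[0].parameters(), "lr": 1.0},
    {"params": model[2].parameters(), "lr": 1e-1},
])
x = torch.randn(4, 1)

# (group index, loss) pairs; both losses are placeholders
steps = [(0, lambda out: out.abs().mean()), (1, lambda out: out.pow(2).mean())]
for group_idx, loss_fn in steps:
    train_only_group(optimizer, group_idx)
    # delete old grads so Adam skips the frozen group in step()
    optimizer.zero_grad(set_to_none=True)
    loss_fn(model(x)).backward()
    optimizer.step()
```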

But be careful and double-check that the optimizer is indeed working as expected.
In particular, note that some optimizers, such as Adam, use internal running stats to update the parameters, so a parameter can still be updated even if its gradient is zero.
In newer PyTorch releases, optimizer.zero_grad() deletes the .grad attributes (sets them to None), since its set_to_none argument defaults to True starting with torch==2.0.0.
Besides being a performance improvement, this also makes sure the optimizer cannot update any parameters with “zero gradients” (which are in fact deleted gradients now), because parameters whose .grad is None are skipped in optimizer.step().
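Both effects can be seen with a single hypothetical parameter p and Adam's default settings: after one real gradient has filled Adam's running averages (exp_avg, exp_avg_sq), an explicit zero gradient still moves p, while a deleted (None) gradient makes step() skip it.

```python
import torch

# A single hypothetical parameter optimized with Adam (default betas/eps)
p = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.Adam([p], lr=0.1)

# Step 1: a real gradient fills Adam's running averages
p.grad = torch.tensor([1.0])
opt.step()
after_first = p.detach().clone()

# Step 2: an explicit zero gradient (what zero_grad(set_to_none=False) leaves)
p.grad = torch.zeros_like(p)
opt.step()
after_second = p.detach().clone()

# Step 3: a deleted gradient (what zero_grad(set_to_none=True) leaves)
p.grad = None
opt.step()
after_third = p.detach().clone()

print(after_first, after_second, after_third)
```

Step 2 changes p through the running averages alone; step 3 leaves it untouched.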
Your opt.zero_grad() usage thus looks correct to me, but I would still double check in a small test script that no parameters are updated when they should be frozen.
Especially if you are using an older PyTorch release, use opt.zero_grad(set_to_none=True) to delete the gradients; otherwise the second opt.step() will also update the previously updated parameters if your optimizer uses internal running stats.
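A small test script along those lines could look like this. It is a sketch with a made-up two-layer model and placeholder losses: it runs the two-phase update and reports whether the group-0 parameters still moved in phase 2, when they were supposed to be frozen.

```python
import torch
import torch.nn as nn

def frozen_group_changed(set_to_none):
    # Runs the two-phase update and reports whether the group-0
    # parameters moved during phase 2, when they were frozen.
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    opt = torch.optim.Adam([
        {"params": model[0].parameters(), "lr": 1.0},
        {"params": model[1].parameters(), "lr": 1e-1},
    ])
    x = torch.randn(4, 1)

    # Phase 1: update group 0 only
    for p in model[1].parameters():
        p.requires_grad_(False)
    opt.zero_grad(set_to_none=set_to_none)
    model(x).mean().backward()
    opt.step()

    # Phase 2: update group 1 only
    for p in model[1].parameters():
        p.requires_grad_(True)
    for p in model[0].parameters():
        p.requires_grad_(False)
    opt.zero_grad(set_to_none=set_to_none)
    before = [p.detach().clone() for p in model[0].parameters()]
    model(x).mean().backward()
    opt.step()
    return any(not torch.equal(b, p) for b, p in zip(before, model[0].parameters()))

print("set_to_none=False, frozen group moved:", frozen_group_changed(False))
print("set_to_none=True,  frozen group moved:", frozen_group_changed(True))
```

With set_to_none=False the frozen group keeps a zero-filled .grad, so Adam's running averages still move it; with set_to_none=True its .grad is None and step() skips it.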
