Does the optimizer skip updating parameters with requires_grad=False?

Do all PyTorch optimizers ignore parameters with requires_grad=False, even if these parameters are explicitly passed when the optimizer is instantiated?

From the source code of optimizer.py, it can be seen that zero_grad has no effect if p.grad is None:

for p in group["params"]:
    if p.grad is not None:
        if set_to_none:
            p.grad = None
        else:
            if p.grad.grad_fn is not None:
                p.grad.detach_()
            else:
                p.grad.requires_grad_(False)
            if not foreach or p.grad.is_sparse:
                p.grad.zero_()
            else:
                assert per_device_and_dtype_grads is not None
                per_device_and_dtype_grads[p.grad.device][
                    p.grad.dtype
                ].append(p.grad)
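
For illustration, here is a minimal sketch (not part of the source above) of what the two branches mean in practice; set_to_none=True is the default in recent PyTorch releases:

import torch
import torch.nn as nn

lin = nn.Linear(1, 1)
lin(torch.randn(1, 1)).sum().backward()
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

opt.zero_grad(set_to_none=False)
print(lin.weight.grad)  # tensor([[0.]]) -- the .grad tensor is kept but zeroed

opt.zero_grad(set_to_none=True)
print(lin.weight.grad)  # None -- the .grad attribute is removed entirely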

Moreover, the step method (shown here for Adam as an example) also has no effect if p.grad is None:

def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()


    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
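            # ... (the rest of the update only runs for params with a valid .grad)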

If that is the case for all optimizers, then setting mymodel.param.requires_grad_(False) would be enough to ensure that these parameters are not updated by the optimizer. Can someone from the PyTorch dev team verify that this is the general behavior?

No, it’s generally not enough since the optimizer does not check for param.requires_grad but for the actual .grad attribute.

Even if the parameter is frozen (via requires_grad = False), it can still have a valid .grad attribute (including zeros) and might be updated. Here is a simple example:

import torch
import torch.nn as nn

lin = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.Adam(lin.parameters(), lr=1.)
x = torch.randn(1, 1)

# update
print(lin.weight)
# Parameter containing:
# tensor([[-0.8696]], requires_grad=True)
out = lin(x)
out.mean().backward()
optimizer.step()
print(lin.weight)
# Parameter containing:
# tensor([[-1.8696]], requires_grad=True)
optimizer.zero_grad()

# freeze
lin.weight.requires_grad = False
out = lin(x)
# out.mean().backward() # would fail
print(lin.weight.grad)
# None
# manually set the .grad
lin.weight.grad = torch.full_like(lin.weight, 1000.)
optimizer.step()
print(lin.weight)
# Parameter containing:
# tensor([[-2.6148]])

# Adam keeps running internal stats and will also update the param with a zero gradient
# (see the state inspection after this example)
lin.weight.grad = torch.zeros_like(lin.weight)
optimizer.step()
print(lin.weight)
# Parameter containing:
# tensor([[-3.1908]])

# setting the .grad to None will skip the update
optimizer.zero_grad(set_to_none=True) # set_to_none=True is the default
optimizer.step()
print(lin.weight)
# Parameter containing:
# tensor([[-3.1908]])
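
As a side note on the zero-gradient step above: Adam keeps per-parameter running statistics, which you can inspect via optimizer.state (a sketch; the exact keys are an implementation detail and may vary between versions):

state = optimizer.state[lin.weight]
print(state.keys())
# e.g. dict_keys(['step', 'exp_avg', 'exp_avg_sq'])
print(state['exp_avg'])
# non-zero first moment left over from the manual 1000. gradient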

@ptrblck Thanks for the answer and the examples! In my mind I had the “typical scenario” where the user doesn’t interact with the .grad attribute manually after setting requires_grad=False (e.g. to freeze part of the model). For example:

import torch
from torch.nn import Linear
from torch.optim import Adam
from torch.utils.data import DataLoader

model = Linear(1, 1)
optimizer = Adam(model.parameters(), weight_decay=0.1)
x = torch.randn(10, 1)
loader = DataLoader(x, batch_size=2)

requires_grad = False
model.bias.requires_grad_(requires_grad)
print(f'\nRequires grad: {requires_grad}\n')

print('Before training')
print(list(model.named_parameters()))

for batch in loader:
    optimizer.zero_grad()
    out = model(batch).sum()
    out.backward()
    optimizer.step()

print('After training')
print(list(model.named_parameters()))
for name, p in model.named_parameters():
    print(name, p.grad)

I think the safest choice is to filter the parameters passed to the optimizer, right?

optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()))
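
For example, a self-contained sketch of that approach (the model and hyperparameters here are made up for illustration):

import torch
from torch import nn
from torch.optim import Adam

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))

# freeze the first layer before building the optimizer
for p in model[0].parameters():
    p.requires_grad_(False)

# only hand the trainable parameters to the optimizer
optimizer = Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)

out = model(torch.randn(2, 4)).sum()
out.backward()
optimizer.step()  # model[0] is never touched, regardless of its .grad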

Yes, typical scenarios might work as you’ve described. I’m just pointing out these limitations since they can lead to hard-to-debug issues.

Yes, I’m an advocate of using all APIs explicitly to avoid potential side effects.