Detached weights still being updated

Hi, I’m building a model that freezes some layers after several training steps. However, the weights that should be frozen are still being updated.
Here’s a piece of code to reproduce the issue. The printed mean of self.conv1.weight keeps changing even though, from step 5 onward, conv1’s output is computed under torch.no_grad() and detached.

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(8, 2)
        self.loss = nn.CrossEntropyLoss(reduction="mean")
    def forward(self, x, y, step):
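        if step < 5:
            # conv1 is trained normally for the first 5 steps
            x = self.conv1(x)
        else:
            # afterwards conv1's output is cut out of the graph, so conv1 receives no new gradient
            with torch.no_grad():
                x = self.conv1(x).detach()
        x = self.conv2(x)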
        x = self.pool(x)[:, :, 0, 0]
        x = self.linear(x)
        print(step, self.conv1.weight.mean())
        loss = self.loss(x, y)
        return loss
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

for i in range(10):
    input_data = torch.rand(size=[2, 3, 16, 16])
    target_data = (torch.rand(size=[2]) * 2).long()
    optimizer.zero_grad()
    loss = model(input_data, target_data, i)
    loss.backward()
    optimizer.step()

Hi, your code doesn’t do what you intend.
Set requires_grad=False for the parameters that you want frozen.
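A minimal sketch of freezing a submodule partway through training, assuming a hypothetical toy nn.Sequential model in place of Net. The frozen parameters get requires_grad_(False), and their stale .grad is cleared so that SGD (with momentum and weight decay still enabled) skips them entirely:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model: two layers, the first of which is frozen mid-training.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
frozen = model[0]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

snapshot = None
for step in range(10):
    if step == 5:
        # Freeze: stop autograd from computing gradients for these parameters
        # and drop any stale .grad so the optimizer skips them.
        for p in frozen.parameters():
            p.requires_grad_(False)
            p.grad = None
        snapshot = frozen.weight.detach().clone()

    x = torch.rand(2, 4)
    y = torch.randint(0, 2, (2,))
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

# The frozen weights have not moved since step 5.
print(torch.equal(frozen.weight, snapshot))
```

Clearing .grad matters here: if it were left as a zero tensor, the momentum buffer could keep moving the weights even though no new gradient arrives.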

I am not an expert, but this might help in debugging the problem.
When using just:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

you get what you expected, so maybe something different happens when momentum or weight_decay is used.
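A minimal sketch, assuming a single hypothetical parameter p, of how momentum alone keeps moving a weight whose gradient is an explicit zero tensor (which is what optimizer.zero_grad(set_to_none=False) leaves behind, the default in older PyTorch versions):

```python
import torch

# A single hypothetical parameter, optimized with momentum and no weight decay.
p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)

# One ordinary step with a non-zero gradient fills the momentum buffer.
p.grad = torch.ones(3)
opt.step()
after_real_step = p.detach().clone()   # 1 - 0.1 * 1 = 0.9 in every entry

# The gradient is now an explicit zero tensor, yet the buffer still carries 0.9 * 1.
p.grad = torch.zeros(3)
opt.step()
after_zero_step = p.detach().clone()   # 0.9 - 0.1 * 0.9 = 0.81 in every entry

print(after_real_step, after_zero_step)
```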

You are right: weight_decay (like momentum) will still update the parameters even with a zero gradient.
@gathierry you could disable this behavior by using optimizer.zero_grad(set_to_none=True), which sets the .grad attributes to None and thus makes the optimizer skip the parameter update. In newer PyTorch versions (2.0 and later), set_to_none is True by default.
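A minimal sketch, assuming a single hypothetical parameter p, showing that once zero_grad(set_to_none=True) leaves .grad as None, SGD skips the parameter, so neither momentum nor weight decay touches it:

```python
import torch

# Hypothetical single parameter with both momentum and weight decay enabled.
p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9, weight_decay=1e-4)

# Fill the momentum buffer with one ordinary step.
p.grad = torch.ones(3)
opt.step()
before = p.detach().clone()

# set_to_none=True replaces .grad with None instead of a zero tensor...
opt.zero_grad(set_to_none=True)
print(p.grad)  # None

# ...so step() skips this parameter entirely.
opt.step()
print(torch.equal(p, before))  # True
```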


Thanks, set_to_none=True solves my problem. But I guess the change is caused by momentum?
I tried weight_decay=0, but the weights were still updated. With momentum=0, the weights stayed the same.
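The update rule of torch.optim.SGD as given in the PyTorch documentation (default dampening of 0, no Nesterov) is consistent with this observation. Here γ is lr, μ is momentum, λ is weight_decay, b is the momentum buffer (for t > 1; on the first step b_1 = g_1), and s is assumed to be the last step with a real gradient, after which .grad is a zero tensor:

$$
\begin{aligned}
g_t &= \nabla_\theta \mathcal{L}(\theta_{t-1}) + \lambda\,\theta_{t-1} \\
b_t &= \mu\, b_{t-1} + g_t \\
\theta_t &= \theta_{t-1} - \gamma\, b_t \\[4pt]
\text{zero gradient, } \lambda = 0:\quad b_{s+k} &= \mu^k\, b_s, \qquad \theta_{s+k} = \theta_{s+k-1} - \gamma\,\mu^k\, b_s \\
\text{zero gradient, } \mu = 0:\quad \theta_t &= (1 - \gamma\lambda)\,\theta_{t-1}
\end{aligned}
$$

With μ = 0.9 the leftover buffer decays only slowly, so the weights keep drifting for many steps. With μ = 0, γ = 0.1 and λ = 10^{-4}, each step only scales the weights by 1 - 10^{-5}, which is likely too small to show up in the printed mean.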

Thanks. I spent a lot of time trying to figure out why they were still getting updated.