Freeze last layers of the model

Hi everyone,

I am trying to implement a VGG perceptual loss in PyTorch and I have some problems with autograd. Specifically, the output of my network (1) goes through a VGG net (2) to calculate features. My ground truth images also go through the VGG net to calculate features. After that, I want to calculate the loss based on these features. However, since the VGG net is pretrained, I think its weights shouldn’t be updated. I tried setting requires_grad = False for every layer in the VGG net, and I get an error before training: “element 0 of tensors does not require grad and does not have a grad_fn”. I suspect that freezing the last layers of the network prevents the other layers from calculating gradients too. Is there any way to work around this, or should I just let the VGG train along with the other parameters? Thanks all for the help.

If you are setting requires_grad = False for all parameters, the error message is expected, as Autograd won’t be able to calculate any gradients, since no parameter requires them.
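A minimal sketch reproducing the error, assuming a single frozen nn.Conv2d stands in for the whole network: when neither the parameters nor the input require gradients, the loss has no grad_fn, so backward() raises a RuntimeError.

```python
import torch
import torch.nn as nn

# Stand-in for a network whose parameters are *all* frozen
model = nn.Conv2d(3, 3, 3, 1, 1)
for param in model.parameters():
    param.requires_grad_(False)

x = torch.randn(1, 3, 24, 24)  # plain input, requires_grad=False by default
loss = model(x).abs().mean()

print(loss.requires_grad)  # False: nothing in the graph needs gradients
print(loss.grad_fn)        # None

error_message = None
try:
    loss.backward()
except RuntimeError as err:
    # Autograd has no graph to traverse, so it refuses to run
    error_message = str(err)
    print(error_message)
```

In the real use case the generator's parameters still require gradients, so the loss does get a grad_fn and this error does not occur.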

Could you describe your use case a bit, i.e. which layers would you like to train and which should be frozen?

I have a generator whose parameters require gradients; it translates an image into a different image of the same shape. I pass the output of the generator into the VGG net (requires_grad=False) (1), and I also pass the ground truth image into the VGG net (requires_grad=False) (2). Finally, I want to calculate the L1 loss between (1) and (2). Thanks.
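For this use case, the frozen feature extractor can be wrapped in a small loss module. This is a sketch under assumptions: PerceptualL1Loss is a hypothetical name, and a tiny nn.Sequential stands in for VGG so the example runs without downloading weights. With torchvision, one would typically pass a slice of a pretrained model's features, e.g. from torchvision.models.vgg16(...).features. The target branch runs under torch.no_grad(), since no gradients are needed there at all.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerceptualL1Loss(nn.Module):
    """Hypothetical helper: L1 distance between features of a frozen extractor."""

    def __init__(self, feature_extractor: nn.Module):
        super().__init__()
        self.features = feature_extractor.eval()  # fixed inference behavior
        for param in self.features.parameters():
            param.requires_grad_(False)  # frozen: weights never receive gradients

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Gradients still flow through the frozen layers back into `output`
        out_feat = self.features(output)
        # The target branch needs no autograd graph at all
        with torch.no_grad():
            target_feat = self.features(target)
        return F.l1_loss(out_feat, target_feat)


# Usage with small stand-ins for the generator and the VGG feature extractor
generator = nn.Conv2d(3, 3, 3, 1, 1)
extractor = nn.Sequential(nn.Conv2d(3, 8, 3, 1, 1), nn.ReLU())
criterion = PerceptualL1Loss(extractor)

data = torch.randn(1, 3, 24, 24)
target = torch.randn(1, 3, 24, 24)

loss = criterion(generator(data), target)
loss.backward()

print(generator.weight.grad is not None)  # True: the generator is trainable
print(extractor[0].weight.grad)           # None: the extractor stays frozen
```

Freezing the extractor does not block gradients from flowing *through* it; requires_grad only controls whether its own weights get gradients.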

This workflow should work as shown in this dummy example:

import torch
import torch.nn as nn

# Setup
generator = nn.Conv2d(3, 3, 3, 1, 1)
dummy = nn.Conv2d(3, 3, 3, 1, 1)

data = torch.randn(1, 3, 24, 24)
target = torch.randn(1, 3, 24, 24)

# Freeze dummy model
for param in dummy.parameters():
    param.requires_grad_(False)

# Forward pass for data
out = generator(data)
out = dummy(out)

# Forward pass for target
target = dummy(target)

# Loss calculation (L1 loss; ((out - target)**2).sqrt() gives NaN gradients where out == target)
loss = (out - target).abs().mean()
loss.backward()

# Check gradients
print(generator.weight.grad)
print(dummy.weight.grad)

As you can see, the generator gets valid gradients, while dummy.weight.grad stays None.
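To check that the frozen model is left untouched during training, a minimal sketch (assuming plain SGD and the same dummy setup as above) passes only the generator's parameters to the optimizer and compares the weights before and after one step.

```python
import torch
import torch.nn as nn

generator = nn.Conv2d(3, 3, 3, 1, 1)
dummy = nn.Conv2d(3, 3, 3, 1, 1)
for param in dummy.parameters():
    param.requires_grad_(False)

# Only the trainable parameters are handed to the optimizer
optimizer = torch.optim.SGD(generator.parameters(), lr=0.1)

data = torch.randn(1, 3, 24, 24)
target = torch.randn(1, 3, 24, 24)

# Snapshot the weights before the update
gen_before = generator.weight.detach().clone()
dummy_before = dummy.weight.detach().clone()

optimizer.zero_grad()
loss = (dummy(generator(data)) - dummy(target)).abs().mean()
loss.backward()
optimizer.step()

print(torch.equal(dummy.weight, dummy_before))    # True: frozen weights unchanged
print(torch.equal(generator.weight, gen_before))  # False: generator was updated
```

If frozen and trainable layers live in the same model, filtering with filter(lambda p: p.requires_grad, model.parameters()) before building the optimizer achieves the same thing.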

Could you post a code snippet to reproduce this issue, so that we could have a look?
