Freeze model weights only on the second call of a layer

I’m currently implementing a network architecture where I have to pass my data through a certain layer multiple times.
The network can be simplified as:

Input ->  Layer A -> Layer B -> Layer A -> Output

However, I don't want Layer A to be updated by the gradients from the second call, i.e. it should be frozen there, while still being updated by the first call. Is this possible with PyTorch? I've tried to come up with something using no_grad, requires_grad, or detach(), but I couldn't get it to work.

A hacky workaround I'm thinking of is to scale the gradients during the backward pass:

Input Gradients <- Layer A <- Layer B <- Scale up gradient by λ <- Layer A <- Scale down gradient by λ <- Output Gradients
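
What I have in mind is roughly something like this custom autograd.Function sketch (GradScale and lam are just placeholder names, and the linear layers only stand in for my real layers):

import torch
import torch.nn as nn

class GradScale(torch.autograd.Function):
    # identity in the forward pass; multiplies the incoming gradient by `factor` in the backward pass
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # gradient w.r.t. x is scaled; `factor` gets no gradient
        return grad_output * ctx.factor, None

# toy stand-ins for my real layers
layerA = nn.Linear(10, 10)
layerB = nn.Linear(10, 10)
x = torch.randn(1, 10)
lam = 1e3  # some large scaling factor λ

h = layerA(x)
h = layerB(h)
h = GradScale.apply(h, lam)          # backward: scales the gradient back up by λ before layerB and the first layerA call
h = layerA(h)
out = GradScale.apply(h, 1.0 / lam)  # backward: scales the gradient down by λ before it reaches the second layerA call
out.mean().backward()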

Is there a cleaner implementation? Thanks.

Wouldn’t torch.no_grad() work?

x = layerA(x)
x = layerB(x)
with torch.no_grad():
    x = layerA(x)

loss = criterion(x, target)

This would make sure that the second call to layerA is not tracked by autograd.

If this is the case, I may have misunderstood how no_grad works. I was under the impression that with no_grad, no gradient would be computed for the second call of layerA, and thus gradients could not be propagated back to layerB or to the first call of layerA either. Can you confirm this is not the case? Thanks for the quick reply.

No, sorry, I was wrong: the output of the no_grad block won't require gradients, so the backward pass would stop there.
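
A quick check with a dummy linear layer should show it:

import torch
import torch.nn as nn

layerA = nn.Linear(10, 10)
x = torch.randn(1, 10)

out = layerA(x)
print(out.requires_grad)  # True

with torch.no_grad():
    out = layerA(out)
print(out.requires_grad)  # False

Since the final output has no grad_fn, calling backward() on a loss computed from it would raise an error, and the graph of the earlier calls is cut off.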

In that case, you could detach the parameters of layerA and use the functional API, which should work:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(2809)

layerA = nn.Linear(10, 10)
layerB = nn.Linear(10, 10)

x = torch.randn(1, 10)

out = layerA(x)
out = layerB(out)

# second call: reuse layerA's detached parameters via the functional API,
# so no gradients are accumulated in layerA for this call
out = F.linear(out, layerA.weight.detach(), layerA.bias.detach())

out.mean().backward()

# contains only the gradient from the first layerA call
print(layerA.weight.grad)
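
If listing the parameters manually gets tedious for a bigger module, the same detaching idea could presumably also be written with torch.func.functional_call from newer PyTorch releases (2.x); a rough sketch, reusing the out from above:

from torch.func import functional_call

# run layerA's forward with detached copies of all of its parameters,
# so this call contributes no gradients to layerA itself
detached_params = {name: p.detach() for name, p in layerA.named_parameters()}
out = functional_call(layerA, detached_params, (out,))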

In my case, my layerA is much more complicated than an nn.Linear layer, so simply detaching the weights would not be practical. I am thinking of deepcopying layerA instead; however, I'm not familiar with how deepcopy interacts with PyTorch. Would you know if this will work?

from copy import deepcopy
import torch

x = torch.randn(1, 10)
layerA = ...
layerB = ...
layerA_2 = deepcopy(layerA)  # independent copy; its parameters are not shared with layerA

out = layerA(x)
out = layerB(out)
out = layerA_2(out)

out.mean().backward()


Yes, your approach should also work.
I compared it to the functional API run and got the same gradients for layerA when using the copied module.
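
For reference, a minimal version of that comparison with the simplified nn.Linear setup from above could look like this (checking that layerA gets the same gradient in both runs):

import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

torch.manual_seed(2809)
layerA = nn.Linear(10, 10)
layerB = nn.Linear(10, 10)
x = torch.randn(1, 10)

# functional-API run: the second call uses detached parameters
out = F.linear(layerB(layerA(x)), layerA.weight.detach(), layerA.bias.detach())
out.mean().backward()
grad_functional = layerA.weight.grad.clone()

# deepcopy run: the second call goes through an independent copy of layerA
layerA.zero_grad()
layerB.zero_grad()
layerA_2 = deepcopy(layerA)
out = layerA_2(layerB(layerA(x)))
out.mean().backward()

print(torch.allclose(grad_functional, layerA.weight.grad))  # True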