I’m currently implementing a network architecture where I have to pass my data through a certain layer multiple times.
The network can be simplified as:
Input -> Layer A -> Layer B -> Layer A -> Output
However, I don’t want Layer A to be updated by the gradient on the latter call, i.e. I want it frozen there, while it is still updated by the first call. Is this possible with PyTorch? I’ve tried to come up with something using no_grad, requires_grad or detach(), but I couldn’t.
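For reference, one approach I’ve seen suggested (this assumes a recent PyTorch where `torch.func.functional_call` is available, and uses `nn.Linear` layers as stand-ins for my real layers) is to run the second pass with detached copies of the parameters, so the gradient still flows through the activations back to Layer B and the first call, but nothing accumulates in Layer A’s weights from the second call:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# nn.Linear layers as stand-ins for the real Layer A / Layer B
layerA = nn.Linear(10, 10)
layerB = nn.Linear(10, 10)
x = torch.randn(1, 10)

h = layerB(layerA(x))

# Second pass: same module, but run with detached parameter tensors, so this
# call contributes nothing to layerA's weight gradients while the gradient
# w.r.t. the activations still propagates back to layerB and the first call.
frozen = {k: v.detach() for k, v in layerA.named_parameters()}
out = functional_call(layerA, frozen, (h,))

out.mean().backward()
```

With this, `layerA.weight.grad` only contains the contribution from the first call.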
A hacky workaround I’m thinking of is to do gradient scaling during the backward pass:
Input Gradients <- Layer A <- Layer B <- Scale up gradient by λ <- Layer A <- Scale down gradient by λ <- Output Gradients
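A minimal sketch of that scaling trick, using a custom autograd Function that is the identity in the forward pass and multiplies the gradient by a constant in the backward pass (`nn.Linear` stand-ins for my real layers; note that λ = 0 is unusable here because of the 1/λ, so this can only attenuate the second call’s update, never fully freeze it):

```python
import torch
import torch.nn as nn

class ScaleGrad(torch.autograd.Function):
    """Identity in forward; multiplies the gradient by `scale` in backward."""
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

def scale_grad(x, scale):
    return ScaleGrad.apply(x, scale)

layerA = nn.Linear(10, 10)  # stand-in for the real Layer A
layerB = nn.Linear(10, 10)
lam = 0.1                   # down-scaling factor for the second pass
x = torch.randn(1, 10)

h = layerA(x)
h = layerB(h)
h = scale_grad(h, 1.0 / lam)  # backward: scales the gradient back up by 1/lam
h = layerA(h)                 # this call's weight gradients end up scaled by lam
out = scale_grad(h, lam)      # backward: scales the incoming gradient down by lam
out.mean().backward()
```

So Layer B and the first Layer A call see the unscaled gradient, while the second Layer A call’s weight gradients are scaled by λ.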
If this is the case, I may have misunderstood how no_grad works. I was under the impression that with no_grad, the gradient would not be computed for the second call of layerA, and thus the gradient could not be propagated to layerB or to the first call of layerA either. Can you confirm this is not the case? Thanks for the quick reply.
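To illustrate my understanding, a quick sketch with `nn.Linear` stand-ins for my real layers:

```python
import torch
import torch.nn as nn

layerA = nn.Linear(10, 10)  # stand-ins for the real layers
layerB = nn.Linear(10, 10)
x = torch.randn(1, 10)

h = layerB(layerA(x))
with torch.no_grad():
    out = layerA(h)  # nothing is recorded for this call

# The graph is cut at the no_grad call: out has no grad_fn, so calling
# out.mean().backward() raises a RuntimeError, and neither layerB nor the
# first layerA call ever receives a gradient through this path.
print(out.requires_grad)  # False
```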
In my case, layerA is much more complicated than an nn.Linear layer, so simply detaching the weights would not be practical. I am thinking of deepcopying layerA instead, but I’m not familiar with how that interacts with PyTorch. Would you know if this will work?
```python
from copy import deepcopy

import torch

x = torch.randn(1, 10)
layerA = ...  # my complicated layer (an nn.Module)
layerB = ...
layerA_2 = deepcopy(layerA)  # independent copy with its own parameter tensors

out = layerA(x)
out = layerB(out)
out = layerA_2(out)  # second pass goes through the copy, not layerA itself
out.mean().backward()
```
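To convince myself, I ran a quick self-contained check with `nn.Linear` as a stand-in: after backward, layerA’s gradient only contains the first call’s contribution, while the second call’s contribution lands in layerA_2’s separate parameters. Two caveats I noticed: the optimizer should only be given `layerA.parameters()`, and layerA_2 has to be re-deepcopied after every optimizer step so the two copies stay in sync.

```python
from copy import deepcopy

import torch
import torch.nn as nn

torch.manual_seed(0)
layerA = nn.Linear(10, 10)   # stand-in for my complicated layer
layerB = nn.Linear(10, 10)
layerA_2 = deepcopy(layerA)  # independent parameter tensors

x = torch.randn(1, 10)
out = layerA_2(layerB(layerA(x)))
out.mean().backward()

# Gradients accumulate in separate tensors: layerA only sees the first call,
# layerA_2 collects the second call's contribution (and is ignored by the
# optimizer as long as only layerA.parameters() are registered with it).
print(layerA.weight.grad is not None)    # True
print(layerA_2.weight.grad is not None)  # True
```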