Hi, I'm confused about how to compute gradients and backpropagate in a way that skips updating one of the modules.
For example, I have a net like this:
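```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.A = ModuleA()  # my own modules, defined elsewhere
        self.B = ModuleB()

    def forward(self, x, y):
        a = self.A(x)
        b = self.B(a)  # B applied to A's output
        c = self.B(y)  # the same B applied directly to y
        return b, c

# training step (Loss is my loss function)
b, c = net(x, y)
loss = Loss(b) + Loss(c)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```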
When I run `backward()` and `step()`, `Loss(c)` updates the weights of `ModuleB`, which is what I want. But what if I also want `Loss(b)` to update the weights of `ModuleA` without updating the weights of `ModuleB`?
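I tried wrapping that call in `torch.no_grad()`:

```python
with torch.no_grad():
    b = self.B(a)
```

but that doesn't solve my problem, because it also stops the gradient of `Loss(b)` from flowing back to `ModuleA`.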
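A minimal sketch of why the `torch.no_grad()` attempt falls short. It uses hypothetical `nn.Linear` stand-ins for `ModuleA`/`ModuleB` and a made-up sum-of-squares `loss_fn` in place of `Loss`. Inside the `no_grad` block no graph is recorded, so `b` does not require grad and `Loss(b)` cannot reach `ModuleA` at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for ModuleA / ModuleB and Loss.
A = nn.Linear(4, 4)
B = nn.Linear(4, 2)

def loss_fn(t):
    return t.pow(2).sum()

x = torch.randn(3, 4)
y = torch.randn(3, 4)

a = A(x)
with torch.no_grad():
    # No graph is recorded here, so b is a constant w.r.t. both A and B.
    b = B(a)
c = B(y)

loss = loss_fn(b) + loss_fn(c)
loss.backward()

# The only path from the loss to A went through b, so A gets no gradient.
print(A.weight.grad)    # None
print(b.requires_grad)  # False
```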
Is there a better way to do this? Thanks a lot!
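One possible approach, sketched under assumptions rather than offered as the definitive fix: compute `b` by running `B` with *detached copies* of its parameters, using `torch.func.functional_call` (assumes PyTorch 2.0 or later). Gradients still flow through `B`'s computation back into `a` and therefore `A`, but the `Loss(b)` term adds nothing to `B`'s own `.grad`; the ordinary `B(y)` call keeps `B` trainable through `Loss(c)`. As before, `A`, `B` and `loss_fn` are hypothetical stand-ins for `ModuleA`, `ModuleB` and `Loss`.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for ModuleA / ModuleB and Loss.
A = nn.Linear(4, 4)
B = nn.Linear(4, 2)

def loss_fn(t):
    return t.pow(2).sum()

x = torch.randn(3, 4)
y = torch.randn(3, 4)

a = A(x)
# Same values as B's parameters, but cut off from autograd, so this call
# treats B's weights as constants while still propagating gradients to `a`.
frozen_B_params = {name: p.detach() for name, p in B.named_parameters()}
b = functional_call(B, frozen_B_params, (a,))
# Ordinary call: B's parameters receive gradients from this term.
c = B(y)

loss = loss_fn(b) + loss_fn(c)
loss.backward()
# A.weight.grad now comes only from loss_fn(b);
# B.weight.grad now comes only from loss_fn(c).
```

In a real training loop the usual `optimizer.zero_grad()` / `loss.backward()` / `optimizer.step()` sequence stays unchanged; only the forward call for `b` differs. Inside `Net.forward` this would correspond to replacing `self.B(a)` with the `functional_call` on `self.B`.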