Hi, I'm confused about how to compute gradients and run backpropagation in a way that skips updating certain modules.

For example, I have a net as follows:

```
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.A = ModuleA()
        self.B = ModuleB()

    def forward(self, x, y):
        a = self.A(x)
        b = self.B(a)
        c = self.B(y)
        return b, c

# training step
b, c = net(x, y)
loss = Loss(b) + Loss(c)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

When `backward()` and `step()` run, `Loss(c)` leads to updating the weights of `ModuleB`; that's fine. But at the same time, what if I want to update the weights of `ModuleA` according to `Loss(b)` without updating the weights of `ModuleB`?
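To make the behavior I'm after concrete, here's a toy sketch (plain `nn.Linear` layers stand in for `ModuleA`/`ModuleB`; that's just my assumption for illustration): gradients from the `b` path should reach `A`'s weights while leaving `B`'s weights untouched. Temporarily freezing `B`'s parameters with `requires_grad_(False)` seems to give this, though I'm not sure it's the right way:

```python
import torch
import torch.nn as nn

# Stand-ins for ModuleA / ModuleB (assumption: simple Linear layers).
A = nn.Linear(4, 4)
B = nn.Linear(4, 4)

x = torch.randn(2, 4)

# Freeze B's parameters for the Loss(b) path only.
for p in B.parameters():
    p.requires_grad_(False)

b = B(A(x))         # the gradient still flows *through* B into A
b.sum().backward()  # stand-in for Loss(b).backward()

print(A.weight.grad is not None)  # True  -> A would be updated
print(B.weight.grad is None)      # True  -> B stays untouched
```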

I tried to use

```
with torch.no_grad():
    b = self.B(a)
```

but under `no_grad` the output `b` is detached from the graph, so `Loss(b)` produces no gradient for `ModuleA` either; it doesn't solve my problem.
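For reference, here's a minimal runnable version of what I tried (again with `nn.Linear` stand-ins for my modules, just an assumption), which shows why it isn't a solution: under `torch.no_grad()` the output `b` carries no autograd graph at all, so `Loss(b)` can never send a gradient back to `A`:

```python
import torch
import torch.nn as nn

A = nn.Linear(4, 4)  # stand-in for ModuleA
B = nn.Linear(4, 4)  # stand-in for ModuleB

x = torch.randn(2, 4)
a = A(x)
with torch.no_grad():
    b = B(a)         # b is detached from the autograd graph

loss_b = b.sum()     # stand-in for Loss(b)
print(loss_b.requires_grad)  # False: no backward path exists,
                             # so A gets no gradient from this loss
```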

Is there a better solution? Thanks a lot!