How to compute gradients and backpropagate while skipping one module's weight update

Hi, I'm confused about how to compute gradients and backpropagate while skipping the weight update of one module for part of the loss.
For example, I have the following network:

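import torch
import torch.nn as nn

# ModuleA, ModuleB, Loss, x and y are placeholders for my own modules, loss and inputs
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.A = ModuleA()
        self.B = ModuleB()

    def forward(self, x, y):
        a = self.A(x)
        b = self.B(a)  # B applied to A's output
        c = self.B(y)  # the same B applied directly to y
        return b, c

net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

b, c = net(x, y)
loss = Loss(b) + Loss(c)
optimizer.zero_grad()
loss.backward()
optimizer.step()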

When I call backward() and step(), Loss(c) updates the weights of ModuleB, which is what I want. But Loss(b) also updates ModuleB. What if I want Loss(b) to update the weights of ModuleA without updating the weights of ModuleB?
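A minimal sketch of this default behavior, assuming nn.Linear layers as hypothetical stand-ins for ModuleA and ModuleB and a plain sum as a stand-in for Loss. Both calls to B contribute to B's gradient, which is easy to see on the bias: every row of b and every row of c adds 1 to each bias entry.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for ModuleA / ModuleB
A = nn.Linear(4, 4)
B = nn.Linear(4, 2)

x = torch.randn(3, 4)
y = torch.randn(3, 4)

b = B(A(x))
c = B(y)
loss = b.sum() + c.sum()  # stand-in for Loss(b) + Loss(c)
loss.backward()

# 3 rows from the b path + 3 rows from the c path -> 6 per bias entry
print(B.bias.grad)                # tensor([6., 6.])
print(A.weight.grad is not None)  # True: A is trained through b
```

With only the c path contributing, the bias gradient would be 3 per entry instead of 6.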
I tried wrapping the first call to B in torch.no_grad():

    with torch.no_grad():
        b = self.B(a)

but that doesn't seem to solve my problem.
Is there a better solution? Thanks a lot!


Yes, running it under torch.no_grad() prevents gradients from flowing back to ModuleA at all, because no graph is recorded for b, so Loss(b) has no path back to A.
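A minimal sketch of that effect, again assuming nn.Linear layers as hypothetical stand-ins for ModuleA and ModuleB and a sum as a stand-in loss. Under torch.no_grad(), b comes out with requires_grad=False, so backward() only reaches B through c and never reaches A.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for ModuleA / ModuleB
A = nn.Linear(4, 4)
B = nn.Linear(4, 2)

x = torch.randn(3, 4)
y = torch.randn(3, 4)

a = A(x)
with torch.no_grad():
    b = B(a)  # no graph is recorded for this call
c = B(y)

print(b.requires_grad)  # False: b is cut off from both A and B

loss = b.sum() + c.sum()  # b.sum() is just a constant here
loss.backward()

print(A.weight.grad)  # None: ModuleA receives no gradient at all
```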

Instead, you can temporarily set requires_grad to False on ModuleB's parameters while computing b:

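self.B.requires_grad_(False)  # B's weights get no gradient from the b path
b = self.B(a)                 # gradients still flow through B back into A
self.B.requires_grad_(True)   # re-enable before the second call
c = self.B(y)                 # B is trained by Loss(c) only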

This ensures that the loss computed from b does not contribute to the gradients of ModuleB's parameters, while gradients still flow through B back into ModuleA.
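A runnable sketch of this approach, assuming nn.Linear layers as hypothetical stand-ins for ModuleA and ModuleB and a sum as a stand-in for Loss. It checks that A still receives a gradient through b, and that B's gradient matches what Loss(c) alone would produce (computed separately with torch.autograd.grad, which does not touch .grad).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-ins for ModuleA / ModuleB
        self.A = nn.Linear(4, 4)
        self.B = nn.Linear(4, 2)

    def forward(self, x, y):
        a = self.A(x)
        # Freeze B only for the a -> b path: autograd still
        # differentiates through B w.r.t. its input a.
        self.B.requires_grad_(False)
        b = self.B(a)
        self.B.requires_grad_(True)
        c = self.B(y)  # B is trainable on this path
        return b, c

net = Net()
x = torch.randn(3, 4)
y = torch.randn(3, 4)

b, c = net(x, y)
loss = b.sum() + c.sum()  # stand-in for Loss(b) + Loss(c)
loss.backward()

# Reference gradient for B's weight from Loss(c) alone
ref = torch.autograd.grad(net.B(y).sum(), net.B.weight)[0]

print(net.A.weight.grad is not None)           # True: A is updated via b
print(torch.allclose(net.B.weight.grad, ref))  # True: B only sees Loss(c)
print(net.B.bias.grad)                         # tensor([3., 3.])
```

The bias gradient is 3 per entry (one per row of y) rather than 6, confirming that the b path no longer feeds into B's parameters.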