Mechanics of the optimizer momentum when updating a conditional neural network module

The question is after the code block. My module has two parts, partA and partB, which are used conditionally in the forward function:

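```python
import torch.nn as nn
import torch.optim as optim

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # Layer contents were elided in the post; these are placeholder layers
        self.partA = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        self.partB = nn.Sequential(nn.Linear(10, 1))

    def forward(self, input_tensor, condition):
        if condition == "B":
            return self.partB(input_tensor)
        elif condition == "A":  # for clarity. could just use 'else' here.
            return self.partB(self.partA(input_tensor))

## initialize and train
# device, num_epochs, trainloader and criterion are defined elsewhere
model = Network().to(device)
opt = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

for epoch in range(num_epochs):
    for i, (input_tensor, labels) in enumerate(trainloader):
        # condition A: gradients flow through partA and partB
        opt.zero_grad()
        outputA = model(input_tensor, "A")
        err = criterion(outputA, labels)  # loss functions take (input, target)
        err.backward()
        opt.step()

        # condition B: gradients flow through partB only
        opt.zero_grad()
        outputB = model(input_tensor, "B")
        err = criterion(outputB, labels)
        err.backward()
        opt.step()
```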

During training, when condition B is triggered, gradients flow through partB of the network but not through partA. Will the momentum term in SGD still cause an update to partA when opt.step() is called, even though partA was not part of the computational graph during the condition B step? If so, how should I avoid this? Should I give partA and partB separate optimizers?
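A minimal sketch of the concern, assuming torch.optim.SGD and a partA gradient that is zeroed (a zero tensor, not None) during the condition B step. The single scalar parameter `w` is a hypothetical stand-in for one partA weight:

```python
import torch

# Hypothetical scalar standing in for one weight of partA
w = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)

# Step 1 (condition A): partA is in the graph and receives a real gradient
w.grad = torch.tensor(1.0)
opt.step()
after_a = w.item()  # buffer = 1.0, so w = 1.0 - 0.1 * 1.0 = 0.9

# Step 2 (condition B): partA is not in the graph, but its gradient is zero, not None
w.grad = torch.zeros_like(w)
opt.step()
after_b = w.item()  # buffer = 0.9 * 1.0 + 0.0 = 0.9, so w = 0.9 - 0.1 * 0.9 = 0.81

print(after_a, after_b)
```

Even with a zero gradient, the stored momentum buffer keeps moving the parameter on every step.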


I would indeed try to use separate optimizers here.
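A rough sketch of what separate optimizers could look like. `partA`, `partB`, `opt_A`, `opt_B` and the layer sizes are hypothetical stand-ins; with the `Network` class above you would pass `model.partA.parameters()` and `model.partB.parameters()` instead:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins for partA and partB (layer sizes are illustrative)
partA = nn.Linear(4, 4)
partB = nn.Linear(4, 1)

# One optimizer per sub-network, so each keeps its own momentum buffers
opt_A = optim.SGD(partA.parameters(), lr=0.1, momentum=0.9)
opt_B = optim.SGD(partB.parameters(), lr=0.1, momentum=0.9)
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Condition A: both parts are in the graph, so step both optimizers
opt_A.zero_grad()
opt_B.zero_grad()
criterion(partB(partA(x)), y).backward()
opt_A.step()
opt_B.step()

# Condition B: only partB is in the graph, so only opt_B steps;
# partA's weights and momentum buffer are left untouched
snapshot = partA.weight.detach().clone()
opt_B.zero_grad()
criterion(partB(x), y).backward()
opt_B.step()
partA_unchanged = torch.equal(partA.weight, snapshot)
print(partA_unchanged)  # True
```

The design choice is explicit: which parameters move is decided by which optimizer you step, not by what happens to be in the graph.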
I think you could also replace opt.zero_grad() with `for p in model.parameters(): p.grad = None` to see if that skips the momentum update, but this relies on behavior that is not obvious to the reader, so a later change to the code might easily break it again.
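A minimal sketch of the `p.grad = None` variant with a single shared optimizer. It relies on `torch.optim.SGD` skipping parameters whose `.grad` is `None`; `partA`, `partB`, the helper `set_grads_to_none` and the layer sizes are hypothetical. As far as I know, `opt.zero_grad()` itself sets gradients to `None` by default since PyTorch 2.0 (`set_to_none=True`):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins for partA and partB, sharing one optimizer
partA = nn.Linear(4, 4)
partB = nn.Linear(4, 1)
params = list(partA.parameters()) + list(partB.parameters())
opt = optim.SGD(params, lr=0.1, momentum=0.9)
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

def set_grads_to_none():
    # Hypothetical helper: drop gradients instead of zeroing them
    for p in params:
        p.grad = None

# Condition A: both parts receive gradients and build up momentum
set_grads_to_none()
criterion(partB(partA(x)), y).backward()
opt.step()

# Condition B: partA is not in the graph, so its .grad stays None after backward
set_grads_to_none()
criterion(partB(x), y).backward()
snapshot = partA.weight.detach().clone()
opt.step()  # SGD skips parameters whose .grad is None, momentum buffer included
grad_is_none = partA.weight.grad is None
partA_unchanged = torch.equal(partA.weight, snapshot)
print(grad_is_none, partA_unchanged)  # True True
```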

Best regards


Thanks Thomas! I separated them into multiple optimizers and it seems to work.