Correct way to do alternating updates with multiple optimizers

I want to alternately update two networks, say m1 and m2. Both networks are submodules of another nn.Module. For simplicity, assume that both m1 and m2 are linear layers.

I have implemented this as shown below. However, my loss remains nearly constant, which suggests that the parameters aren't being updated correctly.

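```python
optimize_m1 = True
for x1, x2 in train_loader:
    # Clear stale gradients for both modules
    m1_optimizer.zero_grad()
    m2_optimizer.zero_grad()

    loss = model(x1, x2, optimize_m1=optimize_m1)
    loss.backward()  # computes gradients for both m1 and m2

    # Update only one module per iteration
    if optimize_m1:
        m1_optimizer.step()
    else:
        m2_optimizer.step()

    optimize_m1 = not optimize_m1  # switch modules for the next batch
```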

The model is defined as follows:

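```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        self.m1 = nn.Linear(in_dim, out_dim)
        self.m2 = nn.Linear(in_dim, out_dim)
        # Stand-in for the unspecified Activation() module
        self.act = nn.ReLU()

    def forward(self, x1, x2, optimize_m1=True):
        # optimize_m1 matches the keyword the training loop passes; it is unused here
        out1 = self.m1(x1)
        out2 = self.m2(x2)

        out1 = self.act(out1)
        out2 = self.act(out2)

        # F.mse_loss in place of the undefined mse helper
        loss = F.mse_loss(out1, out2)
        return loss
```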

For the optimizers, I am using SGD, defined as:

```python
import torch

# One optimizer per module, each owning only that module's parameters
m1_optimizer = torch.optim.SGD(model.m1.parameters(), lr=config.lr)
m2_optimizer = torch.optim.SGD(model.m2.parameters(), lr=config.lr)
```

Also, do I need to call .detach() on out1 and out2 alternately when computing the loss? I am fairly sure this simple architecture works, because I replicated it with linear layers in NumPy and that version works correctly. I'm trying to figure out where I am going wrong in the PyTorch implementation.

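Since you only call step() on the optimizer of the module being trained in each iteration, you don't need to detach any tensors. backward() still computes gradients for both m1 and m2, but only the optimizer whose step() is called updates its parameters, and the zero_grad() calls at the start of the next iteration discard the unused gradients.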
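A minimal sketch of why detaching isn't needed, assuming ReLU as the activation and random toy inputs (neither is given in the thread): backward() fills .grad for both modules, detaching out2 leaves m1's gradient unchanged, and calling step() on one optimizer leaves the other module's parameters untouched. The helpers compute_loss and zero_all are hypothetical names for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for m1/m2; ReLU is an assumed activation.
m1 = nn.Linear(4, 3)
m2 = nn.Linear(4, 3)
m1_optimizer = torch.optim.SGD(m1.parameters(), lr=0.1)
m2_optimizer = torch.optim.SGD(m2.parameters(), lr=0.1)
x1, x2 = torch.randn(8, 4), torch.randn(8, 4)

def compute_loss(detach_out2=False):
    out1 = F.relu(m1(x1))
    out2 = F.relu(m2(x2))
    if detach_out2:
        out2 = out2.detach()  # stop gradients from flowing into m2
    return F.mse_loss(out1, out2)

def zero_all():
    # set_to_none=True makes "no gradient" show up as .grad being None
    m1_optimizer.zero_grad(set_to_none=True)
    m2_optimizer.zero_grad(set_to_none=True)

# Without detach: backward() populates .grad for both modules.
zero_all()
compute_loss().backward()
m1_grad_plain = m1.weight.grad.clone()
m2_has_grad = m2.weight.grad is not None

# With out2 detached: m1 gets the same gradient, m2 gets none.
zero_all()
compute_loss(detach_out2=True).backward()
m1_grad_detached = m1.weight.grad.clone()
m2_grad_after_detach = m2.weight.grad

# Stepping only m1_optimizer leaves m2's parameters untouched.
zero_all()
compute_loss().backward()
m1_before = m1.weight.detach().clone()
m2_before = m2.weight.detach().clone()
m1_optimizer.step()
m1_changed = not torch.equal(m1.weight, m1_before)
m2_changed = not torch.equal(m2.weight, m2_before)

print(m2_has_grad, torch.allclose(m1_grad_plain, m1_grad_detached),
      m2_grad_after_detach is None, m1_changed, m2_changed)
```

In the alternating loop this means the gradients the inactive module receives are simply thrown away by the next zero_grad(); detaching would only save the cost of computing them.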

You can verify this by inspecting the .grad attributes of the parameters, which will be populated for both modules, and checking that only the intended module's parameters actually change in each iteration.
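One way to run that check is sketched below. It assumes ReLU for the unspecified activation and random tensors in place of train_loader (both hypothetical), and snapshot/changed are illustrative helper names, not PyTorch APIs. Each iteration records whether m1 and m2 changed after step().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class Model(nn.Module):
    # Mirrors the question's model; ReLU is an assumed stand-in for Activation().
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.m1 = nn.Linear(in_dim, out_dim)
        self.m2 = nn.Linear(in_dim, out_dim)
        self.act = nn.ReLU()

    def forward(self, x1, x2):
        return F.mse_loss(self.act(self.m1(x1)), self.act(self.m2(x2)))

def snapshot(module):
    # Copy every parameter so later in-place updates don't affect the copy.
    return [p.detach().clone() for p in module.parameters()]

def changed(module, before):
    # True if any parameter differs from its snapshot.
    return any(not torch.equal(p, b) for p, b in zip(module.parameters(), before))

model = Model(4, 3)
m1_optimizer = torch.optim.SGD(model.m1.parameters(), lr=0.1)
m2_optimizer = torch.optim.SGD(model.m2.parameters(), lr=0.1)

# Synthetic batches in place of train_loader (hypothetical data).
batches = [(torch.randn(8, 4), torch.randn(8, 4)) for _ in range(4)]

history = []  # one (optimize_m1, m1_changed, m2_changed) tuple per iteration
optimize_m1 = True
for x1, x2 in batches:
    m1_optimizer.zero_grad()
    m2_optimizer.zero_grad()
    loss = model(x1, x2)
    loss.backward()

    # Both modules receive gradients from backward().
    assert model.m1.weight.grad is not None and model.m2.weight.grad is not None

    m1_before, m2_before = snapshot(model.m1), snapshot(model.m2)
    (m1_optimizer if optimize_m1 else m2_optimizer).step()
    history.append((optimize_m1, changed(model.m1, m1_before), changed(model.m2, m2_before)))
    optimize_m1 = not optimize_m1

for step, (opt_m1, c1, c2) in enumerate(history):
    print(f"iter {step}: optimize_m1={opt_m1} m1_changed={c1} m2_changed={c2}")
```

If the alternation works, m1_changed should track optimize_m1 and m2_changed its negation on every line.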
