I want to alternately update two networks, say m1 and m2. Both networks are submodules of another nn.Module. For simplicity, assume that both m1 and m2 are linear layers.
Currently, I have implemented this as shown below. However, my loss stays nearly constant, which suggests that the parameters aren't being updated correctly.
optimize_m1 = True
for x1, x2 in train_loader:
    # Clear the gradients of both sub-networks
    m1_optimizer.zero_grad()
    m2_optimizer.zero_grad()
    loss = model(x1, x2, optimize_m1=optimize_m1)
    loss.backward()
    # Step only the network whose turn it is
    if optimize_m1:
        m1_optimizer.step()
    else:
        m2_optimizer.step()
    optimize_m1 = not optimize_m1  # alternate every batch
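One way to rule out the update logic is to check that a single alternating step moves only the intended layer. This is a minimal sketch, assuming a toy model in which nn.Tanh stands in for the unspecified Activation and F.mse_loss stands in for mse; ToyModel and changed are hypothetical names.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToyModel(nn.Module):
    # Hypothetical stand-in for Model: nn.Tanh replaces the unspecified
    # Activation and F.mse_loss replaces mse (both are assumptions).
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.m1 = nn.Linear(in_dim, out_dim)
        self.m2 = nn.Linear(in_dim, out_dim)
        self.act = nn.Tanh()

    def forward(self, x1, x2):
        return F.mse_loss(self.act(self.m1(x1)), self.act(self.m2(x2)))

def changed(before, after):
    # True if any tensor differs between two state_dict snapshots
    return any(not torch.equal(before[k], after[k]) for k in before)

model = ToyModel(4, 3)
m1_opt = torch.optim.SGD(model.m1.parameters(), lr=0.1)
m2_opt = torch.optim.SGD(model.m2.parameters(), lr=0.1)

x1, x2 = torch.randn(8, 4), torch.randn(8, 4)
# Snapshot copies, since state_dict tensors share storage with the parameters
m1_before = copy.deepcopy(model.m1.state_dict())
m2_before = copy.deepcopy(model.m2.state_dict())

m1_opt.zero_grad()
m2_opt.zero_grad()
loss = model(x1, x2)
loss.backward()
m1_opt.step()  # only m1 should move on this iteration

m1_moved = changed(m1_before, model.m1.state_dict())
m2_moved = changed(m2_before, model.m2.state_dict())
print(m1_moved, m2_moved)  # expected: True False
```

If this behaves as expected with the real model and data, the alternating update itself is working and the cause of the flat loss is probably elsewhere.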
The model is defined as follows:
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        self.m1 = nn.Linear(in_dim, out_dim)
        self.m2 = nn.Linear(in_dim, out_dim)
        self.act = Activation()  # my activation module (definition not shown)

    # The keyword must match the training loop's call: optimize_m1=...
    def forward(self, x1, x2, optimize_m1=True):
        out1 = self.m1(x1)
        out2 = self.m2(x2)
        out1 = self.act(out1)
        out2 = self.act(out2)
        loss = mse(out1, out2)  # mean squared error between the two outputs
        return loss
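If the flag passed to forward is meant to control which branch receives gradients, one option is to detach the output of the network that is not being updated. This is a sketch under the same assumptions (nn.Tanh in place of Activation, F.mse_loss in place of mse); AlternatingModel is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlternatingModel(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.m1 = nn.Linear(in_dim, out_dim)
        self.m2 = nn.Linear(in_dim, out_dim)
        self.act = nn.Tanh()  # assumption: stands in for Activation()

    def forward(self, x1, x2, optimize_m1=True):
        out1 = self.act(self.m1(x1))
        out2 = self.act(self.m2(x2))
        # Treat the frozen branch's output as a constant target
        if optimize_m1:
            out2 = out2.detach()
        else:
            out1 = out1.detach()
        return F.mse_loss(out1, out2)

torch.manual_seed(0)
model = AlternatingModel(4, 3)
loss = model(torch.randn(8, 4), torch.randn(8, 4), optimize_m1=True)
loss.backward()
# m1 received gradients; m2 was cut out of the backward graph
print(model.m1.weight.grad is not None, model.m2.weight.grad is None)
```

Because the detached branch never enters the backward graph, its .grad stays None and backward does a little less work.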
For the optimizers, I am using SGD, defined as:
m1_optimizer = torch.optim.SGD(list(model.m1.parameters()), lr=config.lr)
m2_optimizer = torch.optim.SGD(list(model.m2.parameters()), lr=config.lr)
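A quick check that the two optimizers split the model's parameters with no overlap and nothing missing can also help. The list(...) wrapper is optional, since torch.optim optimizers accept any iterable of parameters. In this sketch TwoLinear and param_ids are hypothetical, and lr is a placeholder for config.lr.

```python
import torch
import torch.nn as nn

class TwoLinear(nn.Module):
    # Hypothetical minimal module with the same m1/m2 layout as Model
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.m1 = nn.Linear(in_dim, out_dim)
        self.m2 = nn.Linear(in_dim, out_dim)

model = TwoLinear(4, 3)
lr = 0.01  # assumption: placeholder for config.lr
m1_optimizer = torch.optim.SGD(model.m1.parameters(), lr=lr)
m2_optimizer = torch.optim.SGD(model.m2.parameters(), lr=lr)

def param_ids(optimizer):
    # Identities of every tensor this optimizer will update
    return {id(p) for group in optimizer.param_groups for p in group["params"]}

ids1, ids2 = param_ids(m1_optimizer), param_ids(m2_optimizer)
all_ids = {id(p) for p in model.parameters()}
# Disjoint sets that together cover every parameter of the model
print(ids1.isdisjoint(ids2), (ids1 | ids2) == all_ids)
```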
Also, when computing the loss, do I need to call .detach() on out1 and out2 alternately while optimizing the model? I am pretty sure this simple architecture works, because I replicated it with linear layers in NumPy and that code works correctly. I'm trying to figure out where I'm going wrong in the PyTorch implementation.