Can I update two connected models at the same time?

If I update two connected models, the loss term becomes NaN:

X = # ...
temp = Model1(X)
out = Model2(temp)
loss = loss_func(out, label)

Model1_optimizer.zero_grad()
Model2_optimizer.zero_grad()

loss.backward()

Model1_optimizer.step()
Model2_optimizer.step()

I also tried using one optimizer that contains both models' weights.

params = list(Model1.parameters()) + list(Model2.parameters())
opt = torch.optim.Adam(params, lr=0.002)
X = # ...
temp = Model1(X)
out = Model2(temp)
loss = loss_func(out, label)

opt.zero_grad()
loss.backward()
opt.step()

I still got a NaN loss.

However, when I update only one of the two, it works well.
But I still want to update both for better performance.
What should I do?

It shouldn’t work at all, since you call zero_grad() not at the beginning of the training loop but right in the middle. Also, since one of your models is an inner model, you don’t need to call an optimizer on both; call it on just the outer one.

import torch
import torch.nn as nn

# create an outer model; in __init__, create an instance of the inner model
# and place it in self.basemodel
class OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.basemodel = InnerModel()  # the inner model becomes a submodule

    def forward(self, x):
        # run the inner model inside the outer forward pass
        x = self.basemodel(x)
        return x

model = OuterModel()

# create a single optimizer and pass model.parameters() to it;
# this already includes the parameters of self.basemodel
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

# training loop:
for x, label in loader:  # assuming some DataLoader called loader
    optimizer.zero_grad()
    out = model(x)
    loss = loss_func(out, label)
    loss.backward()
    optimizer.step()

Thanks for the reply.

I moved the zero_grad() call to the beginning of the training loop, but the NaN situation has not changed.
For some reason, I need to use two models instead of one big model.
Could you give me some advice?

You are using one big model anyway, at least memory-wise. And you can always access the inner model directly by calling model.basemodel(x).
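
For example, with the OuterModel from the sketch above, the inner model stays reachable as an attribute:

inner_out = model.basemodel(x)  # run only the wrapped inner model
n_params = sum(p.numel() for p in model.parameters())  # includes basemodel's parameters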

As for your question, it should work fine:

import torch
import torch.nn as nn


class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(10, 100)

    def forward(self, x):
        return self.l1(x)


class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(100, 1)

    def forward(self, x):
        return self.l1(x)


m1 = Model1()
m2 = Model2()

opt1 = torch.optim.SGD(m1.parameters(), lr=1e-3)
opt2 = torch.optim.SGD(m2.parameters(), lr=1e-3)

loss_fn = nn.MSELoss()

batchsize = 32
x = torch.ones((batchsize, 10)).float()
labels = torch.ones((batchsize, 1)).float()  # shape must match the (batchsize, 1) output of m2

for epoch in range(50):

    opt1.zero_grad()
    opt2.zero_grad()
    y = m1(x)
    z = m2(y)
    loss = loss_fn(z, labels)
    print(f'Epoch: {epoch}, loss: {loss.item():.4f}')
    loss.backward()
    opt1.step()
    opt2.step()

I found that the NaN situation was caused by exploding gradients. After applying gradient clipping, it works perfectly.
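
For reference, a minimal sketch of where the clipping step can go in the two-optimizer loop above, assuming norm-based clipping via torch.nn.utils.clip_grad_norm_ (max_norm=1.0 is just an illustrative value):

# inside the training loop from the example above
opt1.zero_grad()
opt2.zero_grad()
loss = loss_fn(m2(m1(x)), labels)
loss.backward()
# clip the gradients of both models before the optimizer steps
torch.nn.utils.clip_grad_norm_(m1.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(m2.parameters(), max_norm=1.0)
opt1.step()
opt2.step()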

Thanks for the help.