Problem training a model-parallel net

Hi,
with this model:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# V, H (input height and width), device0, device1 and lr are defined elsewhere
features = 16*16

# define a simple linear autoencoder
class LinearAECenc(nn.Module):
    def __init__(self):
        super(LinearAECenc, self).__init__()
 
        # encoder
        self.enc1 = nn.Linear(in_features=V*H, out_features=512*3//2)
        self.enc12 = nn.Linear(in_features=512*3//2, out_features=512)
        self.enc2 = nn.Linear(in_features=512, out_features=features)


    def forward(self, x):
        # encoding
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc12(x))
        z = F.relu(self.enc2(x))
        return z

class LinearAECdec(nn.Module):
    def __init__(self):
        super(LinearAECdec, self).__init__()
        # decoder 
        self.dec1 = nn.Linear(in_features=features, out_features=512)
        self.dec12 = nn.Linear(in_features=512, out_features=512*3//2)
        self.dec2 = nn.Linear(in_features=512*3//2, out_features=V*H)

 
    def forward(self, z):

        # decoding
        y = F.relu(self.dec1(z))
        y = F.relu(self.dec12(y))
        reconstruction = torch.sigmoid(self.dec2(y))
        return reconstruction

class AEC(nn.Module):
    def __init__(self):
        super(AEC, self).__init__()
        self.enc = LinearAECenc().to(device0)
        self.dec = LinearAECdec().to(device1)

    def forward(self, x):
        z = self.enc(x.to(device0))
        rec = self.dec(z.to(device1))
        return rec

and this training function:

model = AEC() #.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss(reduction='sum')
def fit(model, dataloader):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        reconstruction = model(data)
        loss = criterion(reconstruction, data.to(device1))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

optimizer = optim.Adam(model.parameters(), lr=lr)
epochs = 5000
train_loss = []
for epoch in (p := tqdm(range(epochs))):
    train_epoch_loss = fit(model, train_s3load)
    train_loss.append(train_epoch_loss)
    p.set_description(f"Loss: {train_loss[-1]:.4f}")

Only the parameters on device1 are updated. Can you give me a hint on what to do to update all parameters?
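
For reference, a minimal sketch of how one can check which parameters actually receive a gradient after a backward pass (it assumes the model, criterion, device1 and train_s3load defined above; the single batch is only for illustration):

# diagnostic sketch: run one batch and print the gradient norm per parameter
# (assumes model, criterion, device1 and train_s3load from the code above)
model.train()
data = next(iter(train_s3load))
data = data.view(data.size(0), -1)
model.zero_grad()
loss = criterion(model(data), data.to(device1))
loss.backward()
for name, p in model.named_parameters():
    grad_norm = None if p.grad is None else p.grad.norm().item()
    print(f"{name}: device={p.device}, grad_norm={grad_norm}")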

Thanks in advance

Could you give a little more context on how you are performing distributed training?

From what I understand, the above piece of code trains the model only on a single GPU if you change model = AEC() #.to(device) to model = AEC().to(device). To train on multiple GPUs, something like torch.distributed must be used.
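
For completeness, a minimal sketch of the torch.distributed / DistributedDataParallel pattern this presumably refers to (data parallelism, i.e. a full model replica per GPU, which is not the model parallelism used in the question; the setup_ddp helper, the rendezvous values and the stand-in model are assumptions):

# hypothetical DDP sketch (data parallel, not model parallel)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    # assumed single-machine rendezvous settings
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_ddp(rank, world_size):
    setup_ddp(rank, world_size)
    model = torch.nn.Linear(256, 256).to(rank)   # stand-in for a real single-device model
    ddp_model = DDP(model, device_ids=[rank])
    # ddp_model is trained like a normal module; gradients are
    # averaged across ranks during backward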

The model is running on one machine with two GPUs (device0, device1). The idea is to run the encoder on one device (device0) and the decoder on the second device (device1). The code is based on this tutorial: [Single-Machine Model Parallel Best Practices — PyTorch Tutorials 2.0.1+cu117 documentation].
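
For context, the core pattern in that tutorial looks roughly like this (a from-memory sketch, not a verbatim copy; layer sizes are arbitrary): one submodule per GPU, explicit .to() calls in forward, and the loss computed against a target moved to the output device.

# rough sketch of the pattern from the linked model-parallel tutorial
import torch
import torch.nn as nn
import torch.optim as optim

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10).to('cuda:0')   # first half on GPU 0
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5).to('cuda:1')    # second half on GPU 1

    def forward(self, x):
        x = self.relu(self.net1(x.to('cuda:0')))
        return self.net2(x.to('cuda:1'))

model = ToyModel()
optimizer = optim.SGD(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
loss_fn(outputs, torch.randn(20, 5).to('cuda:1')).backward()
optimizer.step()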