How to train overlapping model parts with different datasets

Hello everyone! I’ve searched all over the internet, but I’m still not sure how to implement the following. I have been stuck on this problem for weeks and am quite confused.

Assume a model with two parts, partA and partB, and two datasets, data1 and data2.

My goal is to train both partA and partB on data1 while simultaneously training only partB on data2.

Here is pseudocode for what I implemented in PyTorch:

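def train():
    model.train()
    # Iterate over data1 and data2 in lockstep: one batch from each per step
    for (input1, target1), (input2, target2) in zip(data1, data2):
        optimizer.zero_grad()

        # data1 trains partA and partB
        loss1 = criterion(model(input1), target1)
        loss1.backward()

        # data2 trains partB only. requires_grad_() sets the flag on every
        # parameter of partA; plain assignment (partA.requires_grad = False)
        # only sets an unused attribute on the module and freezes nothing.
        model.partA.requires_grad_(False)
        loss2 = criterion(model(input2), target2)
        loss2.backward()
        model.partA.requires_grad_(True)

        # One update using the summed gradients of loss1 and loss2
        optimizer.step()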

The loop iterates over data1 and data2 together, taking one batch from each per step.
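A self-contained version of this loop might look like the minimal sketch below. It assumes a hypothetical `TwoPartModel` in which partA feeds into partB, random toy tensors standing in for data1 and data2, and `zip()` to draw one batch from each `DataLoader` per step (zip stops at the shorter loader). None of these names or shapes come from the original code.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)


class TwoPartModel(nn.Module):
    """Hypothetical model for illustration: partA feeds into partB."""

    def __init__(self):
        super().__init__()
        self.partA = nn.Linear(2, 5)
        self.partB = nn.Linear(5, 1)

    def forward(self, x):
        return self.partB(self.partA(x))


model = TwoPartModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Toy stand-ins for data1 and data2 (random tensors, 4 batches of 10 each)
data1 = DataLoader(TensorDataset(torch.rand(40, 2), torch.rand(40, 1)), batch_size=10)
data2 = DataLoader(TensorDataset(torch.rand(40, 2), torch.rand(40, 1)), batch_size=10)


def train():
    model.train()
    # zip yields one batch from each loader per step
    for (x1, y1), (x2, y2) in zip(data1, data2):
        optimizer.zero_grad()

        # data1: gradients flow into partA and partB
        criterion(model(x1), y1).backward()

        # data2: freeze partA so only partB accumulates gradients
        model.partA.requires_grad_(False)
        criterion(model(x2), y2).backward()
        model.partA.requires_grad_(True)

        # single update using the summed gradients
        optimizer.step()


train()
```

Because `requires_grad_` toggles state on the module, partA is switched back on right after the data2 backward pass so that the next data1 batch trains it again.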

Is my implementation reasonable? If not, could you give me some advice on training separate parts of a model?

Thanks, everyone.

The problem I see in the pseudocode above is that you use the same optimizer and call .zero_grad() on it between data1 and data2. That wipes the gradients from data1 before you call .step().
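This is easy to check: `backward()` adds into each parameter's `.grad` rather than overwriting it, so two `backward()` calls without a `zero_grad()` in between leave the sum of both gradients, while a `zero_grad()` in between leaves only the second. A minimal sketch, assuming a single `nn.Linear` layer and a simple `sum()` loss chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(2, 1)
x1, x2 = torch.rand(4, 2), torch.rand(4, 2)

# Gradient from x1 alone
layer.zero_grad()
layer(x1).sum().backward()
g1 = layer.weight.grad.clone()

# Gradient from x2 alone
layer.zero_grad()
layer(x2).sum().backward()
g2 = layer.weight.grad.clone()

# Two backward() calls with no zero_grad() in between: gradients add up
layer.zero_grad()
layer(x1).sum().backward()
layer(x2).sum().backward()
accumulated = layer.weight.grad.clone()

# zero_grad() between the two calls: the x1 contribution is lost
layer.zero_grad()
layer(x1).sum().backward()
layer.zero_grad()
layer(x2).sum().backward()
wiped = layer.weight.grad.clone()

print(torch.allclose(accumulated, g1 + g2))  # True
print(torch.allclose(wiped, g2))             # True
```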

Oh, my mistake.

I accidentally added an extra ‘optimizer.zero_grad()’ when copying the code.

I have edited the pseudocode.

Thanks.

Here is some sample code you can run to see what happens:

import torch
import torch.nn as nn

batch_size=10

# model[0] plays the role of partA, model[1] the role of partB
model = nn.Sequential(nn.Linear(2,5), nn.Linear(5,1))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

criterion = nn.MSELoss()

data1 = torch.rand(batch_size, 2)
data2 = torch.rand(batch_size, 2)

targ1 = torch.rand(batch_size, 1)
targ2 = torch.rand(batch_size, 1)

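# data1: gradients for both layers
loss1 = criterion(model(data1), targ1)
optimizer.zero_grad()
loss1.backward()
for param in model.parameters():
    print(param.grad)

# data2: freeze the first layer so only the second layer accumulates gradients
model[0].requires_grad_(False)
loss2 = criterion(model(data2), targ2)
loss2.backward()
model[0].requires_grad_(True)
for param in model.parameters():
    print(param.grad)

# one update with the accumulated gradients
optimizer.step()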

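As expected, the gradients of the first Linear layer do not change when data2 is passed through, while those of the second layer do: the data2 gradients are added on top of the data1 gradients before optimizer.step(). Hope it helps.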
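If partA feeds directly into partB (as model[0] feeds model[1] in the sample), one alternative to toggling `requires_grad` is to `detach()` partA's output for the data2 batch, which cuts the graph so no gradient flows back into partA. This is a different technique from the one in the thread, and it assumes partA comes strictly before partB; it would not work if partB sat upstream of partA. The sketch below compares both approaches on identical copies of the sample model:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
# model[0] plays the role of partA, model[1] of partB (partA feeds partB)
model = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 1))
twin = copy.deepcopy(model)  # identical copy for the requires_grad_ version
criterion = nn.MSELoss()
data1, targ1 = torch.rand(10, 2), torch.rand(10, 1)
data2, targ2 = torch.rand(10, 2), torch.rand(10, 1)

# Approach 1: detach partA's output for data2
criterion(model(data1), targ1).backward()
partA_grad = model[0].weight.grad.clone()
features = model[0](data2).detach()  # no gradient flows back into partA
criterion(model[1](features), targ2).backward()

# Approach 2: freeze partA with requires_grad_, as in the sample above
criterion(twin(data1), targ1).backward()
twin[0].requires_grad_(False)
criterion(twin(data2), targ2).backward()
twin[0].requires_grad_(True)

print(torch.equal(model[0].weight.grad, partA_grad))              # True
print(torch.allclose(model[1].weight.grad, twin[1].weight.grad))  # True
print(torch.allclose(model[0].weight.grad, twin[0].weight.grad))  # True
```

Both give the same gradients; the detach version avoids flipping a flag on the module, so partA cannot be left frozen by accident if an exception occurs between the two `requires_grad_` calls.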

Thank you very much.

Since I am new to PyTorch, this helped me a lot.
