Hello everyone! I’ve been searching all around the internet, but I’m not sure how to implement this. I’ve been stuck on this problem for weeks and am quite confused.

Assume a model with partA and partB.

And we have data1 and data2.

My goal is to train partA and partB with data1 and train partB with data2 simultaneously.

Here is pseudocode for what I implemented in PyTorch.

```
def train():
    model.train()
    for input, target in data1:
        loss1 = criterion(model(input), target)
        optimizer.zero_grad()
        loss1.backward()
        for input, target in data2:
            # freeze partA so loss2 only produces gradients for partB
            model.partA.requires_grad_(False)
            loss2 = criterion(model(input), target)
            loss2.backward()
            model.partA.requires_grad_(True)
        optimizer.step()
```

The iterations over data1 and data2 are repeated simultaneously.

Is my implementation reasonable? If not, can you give me some advice about training separate parts of a model?

Thx for all of you.

The problem I see in the above pseudocode is that you’re using the same optimizer and calling `.zero_grad()` on it between data1 and data2. That means your gradients from data1 are getting wiped before you call `.step()`.
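To make the accumulation behavior concrete, here is a minimal sketch (a toy scalar, not the poster’s model) showing why a stray `zero_grad()` between two `backward()` calls loses the first gradient: `backward()` *accumulates* into `.grad`, and `zero_grad()` resets it.

```python
import torch

# A single trainable scalar standing in for a model parameter.
w = torch.ones(1, requires_grad=True)

loss_a = (2 * w).sum()
loss_a.backward()                     # w.grad is now 2
grad_after_first = w.grad.clone()

# No zero_grad() here: the second backward() accumulates on top.
loss_b = (3 * w).sum()
loss_b.backward()                     # w.grad is now 2 + 3 = 5
grad_after_both = w.grad.clone()

print(grad_after_first.item(), grad_after_both.item())
```

If a `zero_grad()` were placed between the two `backward()` calls, the final gradient would be 3 instead of 5, and the data1 contribution would never reach `.step()`.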

Oh, my mistake.

I accidentally wrote an extra `optimizer.zero_grad()` while copying the code over.

I edited the pseudocode.

Thanks.

Here is a sample code you can run to see what happens:

```
import torch
import torch.nn as nn

batch_size = 10
model = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

data1 = torch.rand(batch_size, 2)
data2 = torch.rand(batch_size, 2)
targ1 = torch.rand(batch_size, 1)
targ2 = torch.rand(batch_size, 1)

# First pass: gradients flow into both layers.
loss1 = criterion(model(data1), targ1)
optimizer.zero_grad()
loss1.backward()
for param in model.parameters():
    print(param.grad)

# Second pass: freeze the first Linear layer before the forward pass,
# so loss2 only adds gradients to the second layer.
model[0].requires_grad_(False)
loss2 = criterion(model(data2), targ2)
loss2.backward()
model[0].requires_grad_(True)
for param in model.parameters():
    print(param.grad)

optimizer.step()
```

As expected, the first Linear layer’s gradients do not change during the data2 pass, while the second layer’s do. Hope it helps.
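In case it’s useful, here is an alternative sketch that avoids toggling `requires_grad` entirely. It assumes partB consumes partA’s output (as in the `nn.Sequential` example above, with hypothetical names `partA`/`partB`): calling `.detach()` on partA’s activations during the data2 pass blocks gradients from flowing back into partA.

```python
import torch
import torch.nn as nn

batch_size = 10
partA = nn.Linear(2, 5)
partB = nn.Linear(5, 1)
optimizer = torch.optim.Adam(
    list(partA.parameters()) + list(partB.parameters()), lr=0.001
)
criterion = nn.MSELoss()

x1, y1 = torch.rand(batch_size, 2), torch.rand(batch_size, 1)
x2, y2 = torch.rand(batch_size, 2), torch.rand(batch_size, 1)

optimizer.zero_grad()
# loss1: gradients reach both partA and partB.
loss1 = criterion(partB(partA(x1)), y1)
# loss2: detach() cuts the graph, so gradients only reach partB.
loss2 = criterion(partB(partA(x2).detach()), y2)
(loss1 + loss2).backward()
optimizer.step()
```

This keeps a single optimizer and a single `backward()` call per step, at the cost of having to call partA and partB explicitly instead of going through one `model(input)`.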

Thank you very much.

Since I am a newbie in PyTorch, this helped me a lot.
