Hello everyone! I’ve been searching all around the internet, but I’m not sure how to implement this. I’ve been stuck on this problem for weeks and am quite confused.

Assume a model with partA and partB.

And we have data1 and data2.

My goal is to train partA and partB with data1 and train partB with data2 simultaneously.

Here is pseudocode for what I implemented in PyTorch.

```
def train():
    model.train()
    for input, target in data1:
        loss1 = criterion(model(input), target)
        optimizer.zero_grad()
        loss1.backward()
        for input, target in data2:
            # freeze partA so loss2 only produces gradients for partB
            model.partA.requires_grad_(False)
            loss2 = criterion(model(input), target)
            loss2.backward()
            model.partA.requires_grad_(True)
        optimizer.step()
```

The iterations over data1 and data2 are repeated simultaneously.

Is my implementation reasonable? If not, can you give me some advice about training separate parts of a model?

Thx for all of you.

The problem I see in the above pseudocode is that you’re using the same optimizer and calling `.zero_grad()` on it between data1 and data2. That means your gradients from data1 are getting wiped before you call `.step()`.
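To make the accumulation behavior concrete, here is a minimal sketch (a toy scalar, not the poster’s model) showing why a stray `zero_grad()` between two `backward()` calls loses the first gradient: `backward()` *accumulates* into `.grad`, and `zero_grad()` resets it.

```python
import torch

# A single trainable scalar standing in for a model parameter.
w = torch.ones(1, requires_grad=True)

loss_a = (2 * w).sum()
loss_a.backward()                     # w.grad is now 2
grad_after_first = w.grad.clone()

# No zero_grad() here: the second backward() accumulates on top.
loss_b = (3 * w).sum()
loss_b.backward()                     # w.grad is now 2 + 3 = 5
grad_after_both = w.grad.clone()

print(grad_after_first.item(), grad_after_both.item())
```

If a `zero_grad()` were placed between the two `backward()` calls, the final gradient would be 3 instead of 5, and the data1 contribution would never reach `.step()`.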

Oh, my mistake.

I accidentally wrote an extra `optimizer.zero_grad()` while copying the code over.

I edited the pseudocode.

Thanks.

Here is a sample code you can run to see what happens:

```
import torch
import torch.nn as nn

batch_size = 10
model = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

data1 = torch.rand(batch_size, 2)
data2 = torch.rand(batch_size, 2)
targ1 = torch.rand(batch_size, 1)
targ2 = torch.rand(batch_size, 1)

# First pass: gradients flow into both layers.
loss1 = criterion(model(data1), targ1)
optimizer.zero_grad()
loss1.backward()
for param in model.parameters():
    print(param.grad)

# Second pass: freeze the first Linear layer before the forward pass,
# so loss2 only adds gradients to the second layer.
model[0].requires_grad_(False)
loss2 = criterion(model(data2), targ2)
loss2.backward()
model[0].requires_grad_(True)
for param in model.parameters():
    print(param.grad)

optimizer.step()
```

As expected, the first Linear layer’s gradients do not change during the data2 pass, while the second layer’s do. Hope it helps.
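In case it’s useful, here is an alternative sketch that avoids toggling `requires_grad` entirely. It assumes partB consumes partA’s output (as in the `nn.Sequential` example above, with hypothetical names `partA`/`partB`): calling `.detach()` on partA’s activations during the data2 pass blocks gradients from flowing back into partA.

```python
import torch
import torch.nn as nn

batch_size = 10
partA = nn.Linear(2, 5)
partB = nn.Linear(5, 1)
optimizer = torch.optim.Adam(
    list(partA.parameters()) + list(partB.parameters()), lr=0.001
)
criterion = nn.MSELoss()

x1, y1 = torch.rand(batch_size, 2), torch.rand(batch_size, 1)
x2, y2 = torch.rand(batch_size, 2), torch.rand(batch_size, 1)

optimizer.zero_grad()
# loss1: gradients reach both partA and partB.
loss1 = criterion(partB(partA(x1)), y1)
# loss2: detach() cuts the graph, so gradients only reach partB.
loss2 = criterion(partB(partA(x2).detach()), y2)
(loss1 + loss2).backward()
optimizer.step()
```

This keeps a single optimizer and a single `backward()` call per step, at the cost of having to call partA and partB explicitly instead of going through one `model(input)`.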

Thank you very much.

Since I am a newbie in PyTorch, this helped me a lot.
