Joint Training Discussion

Hi, suppose I have two networks:
unet → for segmentation
resnet → for classification

How do I do joint training? Is the code below correct?


def train_models(inputs, targets, labels):
    unet.train()
    resnet.train()
    segment_output = unet(inputs)         # segmentation prediction
    output = resnet(segment_output)       # classification from the segmentation

    optimizer1.zero_grad()
    optimizer2.zero_grad()
    loss1 = MSE(segment_output, targets)  # segmentation loss
    loss2 = BCE(output, labels)           # classification loss
    loss1.backward(retain_graph=True)     # both losses share the unet graph
    loss2.backward()

    optimizer1.step()
    optimizer2.step()
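(Note: retain_graph=True on the first backward() is needed here; loss2 flows through the same unet graph as loss1, and without it the second backward() raises a RuntimeError because the graph buffers are freed after the first call.)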

Or is this the correct way?

import torch.optim as optim

# One optimizer with two parameter groups, each with its own learning rate
optimizer = optim.Adam([
    {'params': unet.parameters()},                # uses the default lr=1e-4
    {'params': resnet.parameters(), 'lr': 1e-3},  # per-group override
], lr=1e-4)

def train_models(inputs, targets, labels):
    unet.train()
    resnet.train()
    segment_output = unet(inputs)
    output = resnet(segment_output)

    optimizer.zero_grad()
    loss1 = MSE(segment_output, targets)  # segmentation loss
    loss2 = BCE(output, labels)           # classification loss
    loss = loss1 + loss2                  # joint loss, single backward pass
    loss.backward()
    optimizer.step()

If I happen to have a weighted loss, is the second approach better?
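With weights, the second approach stays a one-line change; a minimal sketch, assuming hypothetical weights w_seg and w_cls (not part of the original snippet):

w_seg, w_cls = 1.0, 0.5               # hypothetical loss weights
loss = w_seg * loss1 + w_cls * loss2  # weighted joint loss
loss.backward()
optimizer.step()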

What do you mean by joint training?
Are the networks totally independent?
Does the sum of the losses make sense in any way?
If you just want to train two different models, you can run two separate scripts; there is no need to run them in the same one.

Both snippets will “run”; the difference is the theory behind them, depending on what you are doing.
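For the two snippets above (assuming identical learning rates), the accumulated parameter gradients come out the same either way, because gradients add up across backward() calls. A self-contained toy check, with made-up shapes:

import torch

x = torch.randn(4, 3)
w = torch.randn(3, 2, requires_grad=True)

out = x @ w
loss1 = out.pow(2).mean()
loss2 = out.abs().mean()
loss1.backward(retain_graph=True)   # keep the shared graph alive
loss2.backward()                    # gradients accumulate into w.grad
grad_separate = w.grad.clone()

w.grad = None                       # reset, redo the forward pass
out = x @ w
(out.pow(2).mean() + out.abs().mean()).backward()
print(torch.allclose(grad_separate, w.grad))  # True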

I want to implement segmentation-based classification. The reason I want to join them together is to train end-to-end.
First, I will warm up unet.
Second, I will warm up resnet while unet is frozen (see the sketch after this list).
Third, I will train both networks together with the sum of their losses.
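A minimal sketch of the freezing step in the second phase, assuming the unet and resnet modules from the snippets above:

# Phase 2: warm up resnet while unet is frozen
for p in unet.parameters():
    p.requires_grad_(False)
# ... run the classification warm-up here ...

# Phase 3: unfreeze unet and train both with the summed loss
for p in unet.parameters():
    p.requires_grad_(True)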

My question is: in the third phase, do summing the losses and backpropagating them separately have the same meaning? I hope that unet learns from the classification network too.

By the way, you said both snippets will run, so can I say that the two implementations above have different meanings?

Yeah, then you can sum both losses, as that jointly optimizes the problem.
Keep in mind that the classifier’s gradients affect the segmentation too. If you see that the classification signal degrades the segmentation, then you can train them apart. But in that case it would be better to fully train the segmenter first and then the classifier; I don’t see the point of iterating between the two.
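If the classification signal should not reach unet at all, a common pattern is to cut the gradient path with detach(); a sketch, reusing the names from the snippets above:

segment_output = unet(inputs)
output = resnet(segment_output.detach())  # blocks gradients into unet
loss1 = MSE(segment_output, targets)      # still updates unet
loss2 = BCE(output, labels)               # now only updates resnet
(loss1 + loss2).backward()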
