Loss function of a network with two decoders sharing the same encoder?

I have a network that includes one encoder and two decoders for semantic image segmentation. The two decoders share the same encoder. I’m confused about what the loss function would look like. Any help/suggestion/advice would be greatly appreciated.

You can simply compute a separate loss from each decoder's output, add the two losses together, and backpropagate. Because both outputs depend on the shared encoder, PyTorch's autograd automatically accumulates (sums) the gradients from both losses into the encoder's parameters. Maybe something like:

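```python
# Shared encoder: one forward pass, one feature map used by both heads
feature_map = encoder(image)
out1 = decoder1(feature_map)
out2 = decoder2(feature_map)

# One loss per decoder, then a single combined loss
loss1 = loss_fn1(out1, label1)
loss2 = loss_fn2(out2, label2)
loss = loss1 + loss2
loss.backward()  # encoder grads = contribution from loss1 + contribution from loss2

optimizer.step()       # updates encoder, decoder1 and decoder2 together
optimizer.zero_grad()  # clear grads before the next iteration
```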
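To check the claim that the encoder's gradient from `loss1 + loss2` is just the sum of the gradients from each loss, here is a small self-contained sketch. The tiny `Conv2d` encoder/decoders, the class counts, and the random data are hypothetical stand-ins for a real segmentation model, chosen only so the example runs quickly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy modules standing in for the real encoder/decoders.
encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)
decoder1 = nn.Conv2d(8, 4, kernel_size=1)  # e.g. a 4-class segmentation head
decoder2 = nn.Conv2d(8, 2, kernel_size=1)  # e.g. a 2-class segmentation head
loss_fn = nn.CrossEntropyLoss()  # accepts (N, C, H, W) logits and (N, H, W) targets

image = torch.randn(2, 3, 16, 16)
label1 = torch.randint(0, 4, (2, 16, 16))
label2 = torch.randint(0, 2, (2, 16, 16))

def encoder_grad(which):
    """Backprop loss1, loss2, or their sum and return the encoder weight gradient."""
    for m in (encoder, decoder1, decoder2):
        m.zero_grad()
    feat = encoder(image)
    loss1 = loss_fn(decoder1(feat), label1)
    loss2 = loss_fn(decoder2(feat), label2)
    loss = {"loss1": loss1, "loss2": loss2, "sum": loss1 + loss2}[which]
    loss.backward()
    return encoder.weight.grad.clone()

g1 = encoder_grad("loss1")
g2 = encoder_grad("loss2")
g_sum = encoder_grad("sum")
print(torch.allclose(g_sum, g1 + g2, atol=1e-5))
```

Because backpropagation is linear in the upstream gradient, the two comparisons agree up to floating-point rounding, which is why a single `loss.backward()` on the summed loss is enough.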

Thank you @Dazitu616, that helps. I actually have a follow-up question (sorry I wasn’t very clear in my original question):

How does your solution change if the two decoders are trained together simultaneously, but each on a different dataset? I.e., decoder-1 is trained on trainset-1 and decoder-2 is trained on trainset-2.

Then you could forward data1 (from trainset-1) and data2 (from trainset-2) through the shared encoder separately, and feed each feature map to its own decoder. Like:

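```python
feat1 = encoder(data1)  # batch from trainset-1
feat2 = encoder(data2)  # batch from trainset-2
out1 = decoder1(feat1)  # decoder1 only sees trainset-1
out2 = decoder2(feat2)  # decoder2 only sees trainset-2
# the rest (loss1 + loss2, backward, step, zero_grad) is the same as above
```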
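A minimal end-to-end sketch of that idea, assuming one `DataLoader` per dataset and a single optimizer over all three modules. The toy modules, the random `TensorDataset`s, and the SGD settings are hypothetical placeholders, not part of the original answer.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy modules; replace with the real encoder/decoders.
encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)
decoder1 = nn.Conv2d(8, 4, kernel_size=1)
decoder2 = nn.Conv2d(8, 2, kernel_size=1)
loss_fn = nn.CrossEntropyLoss()

# Hypothetical stand-ins for trainset-1 and trainset-2 (8 samples each).
trainset1 = TensorDataset(torch.randn(8, 3, 16, 16), torch.randint(0, 4, (8, 16, 16)))
trainset2 = TensorDataset(torch.randn(8, 3, 16, 16), torch.randint(0, 2, (8, 16, 16)))
loader1 = DataLoader(trainset1, batch_size=4, shuffle=True)
loader2 = DataLoader(trainset2, batch_size=4, shuffle=True)

# One optimizer over the encoder and both decoders.
params = list(encoder.parameters()) + list(decoder1.parameters()) + list(decoder2.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

losses = []
for epoch in range(2):
    # zip pairs one batch from each loader and stops at the shorter one.
    for (data1, label1), (data2, label2) in zip(loader1, loader2):
        optimizer.zero_grad()
        feat1 = encoder(data1)  # trainset-1 batch through the shared encoder
        feat2 = encoder(data2)  # trainset-2 batch through the shared encoder
        loss1 = loss_fn(decoder1(feat1), label1)
        loss2 = loss_fn(decoder2(feat2), label2)
        loss = loss1 + loss2
        loss.backward()  # encoder receives gradients from both datasets
        optimizer.step()
        losses.append(loss.item())

print(len(losses))  # 2 epochs x 2 paired batches = 4
```

Note that `zip` silently truncates to the shorter loader; if the two datasets differ a lot in size, wrapping the shorter loader in `itertools.cycle` is one common workaround.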