Is it possible to train one U-Net with a shared encoder and dual decoders to perform two different tasks?
Hi Claudia!
I’ve never tried it, and I don’t know offhand of any examples or research, but I think that
this would be possible.
Some commentary:
First, for simplicity, let’s assume that the input to a baseline, single-task U-Net would be a
single-channel gray-scale image and that the output would be a single-channel binary
prediction “image” (e.g., per-pixel logits).
The simplest approach would be to take your “regular” single-task U-Net and change its
final layer to have two output channels (one for task A and the second for task B), instead
of only one.
In this case, almost all of the U-Net is shared between, and jointly trained on, the two tasks,
with only the two output channels being task specific.
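To make that concrete, here is a minimal sketch of the baseline change, assuming the
single-task U-Net ends in a 1×1 convolution (the layer name final_conv and the 64 input
channels are illustrative placeholders, not part of any particular implementation):

```python
import torch.nn as nn

# Baseline single-task head: one output channel of binary logits.
# final_conv = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)

# Two-task variant: identical U-Net, but the last layer emits two
# channels, channel 0 for task A and channel 1 for task B.
final_conv = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
```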
In your proposal, the first half of the U-Net – the “encoder” – is shared and trained jointly,
while the second half – the dual “decoders” – are task specific and trained separately.
The next logical step is to realize that you could switch from joint to separate anywhere
along the U-Net. For example, it would be logically sound to have the first two layers of
the U-Net’s encoder be shared, but have most of the encoder (as well as the decoder)
be dual.
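As a rough sketch of the joint-encoder, dual-decoder version (Encoder and Decoder here
stand in for your own U-Net halves; a real U-Net encoder would also pass skip-connection
feature maps to each decoder, which I omit for brevity):

```python
import torch
import torch.nn as nn

class SharedEncoderDualDecoderUNet(nn.Module):
    """Hypothetical wrapper: one shared encoder feeding two task-specific heads."""
    def __init__(self, encoder: nn.Module, decoder_a: nn.Module, decoder_b: nn.Module):
        super().__init__()
        self.encoder = encoder      # shared and trained jointly on both tasks
        self.decoder_a = decoder_a  # specific to task A
        self.decoder_b = decoder_b  # specific to task B

    def forward(self, x: torch.Tensor):
        # A real U-Net encoder would also return skip-connection features
        # for both decoders to consume; omitted here for simplicity.
        features = self.encoder(x)
        return self.decoder_a(features), self.decoder_b(features)
```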
If you do decide to experiment with a joint-encoder, dual-decoder architecture, it would be
good practice to start with a “regular” two-channel-output U-Net as a baseline for comparison.
(You should also clarify in your own mind the details of your use case. For example, are you
predicting individual pixels to be either task A or task B or neither, so a three-class per-pixel
multi-class classification problem, or would you be making two binary predictions for each pixel
separately for each of task A and task B, so a two-class multi-label classification problem?)
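In PyTorch terms, those two formulations call for differently shaped targets. A quick sketch
(the image size and random values are purely illustrative):

```python
import torch

H, W = 128, 128

# Multi-class formulation: one integer class index per pixel
# (0 = neither, 1 = task A, 2 = task B), as expected by CrossEntropyLoss.
multiclass_target = torch.randint(0, 3, (H, W))             # int64, shape [H, W]

# Multi-label formulation: two independent binary maps (a pixel may be
# positive for both tasks), as expected by BCEWithLogitsLoss.
multilabel_target = torch.randint(0, 2, (2, H, W)).float()  # shape [2, H, W]
```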
Best.
K. Frank
Thank you, Frank. So the thing is, one of the decoders will give a binary-class output, and the other will give a multi-class output. However, the encoder extracts features from the same data space. Is that also possible?
And also, will the combination of the loss functions dedicated to each of the decoders be used to update the weights in the encoder?
Hi Claudia!
Yes. The first decoder would have a single-channel output – the logit for the binary-classification
prediction. The second decoder would have a multi-channel output.
Suppose that the second decoder is classifying pixels for semantic segmentation, let’s say
for “hat,” “shirt,” “pants,” “shoes,” and “background.” Then it would have a five-channel
output – the unnormalized log-probabilities for the five segmentation classes.
In the baseline approach of using a “regular” U-Net that I mentioned in an earlier post, the last
layer of the (single-decoder) U-Net would have six channels. You would take the first channel
and pass it to BCEWithLogitsLoss (for the binary segmentation) and pass the remaining five
channels, grouped together, to CrossEntropyLoss (for the multi-class segmentation).
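A minimal sketch of that channel split (the batch size, image size, and random stand-in
tensors are illustrative; in practice output would come from your U-Net and the targets
from your dataset):

```python
import torch
import torch.nn as nn

batch, H, W = 4, 128, 128
output = torch.randn(batch, 6, H, W, requires_grad=True)   # stand-in for the U-Net output
binary_target = torch.randint(0, 2, (batch, 1, H, W)).float()
class_target = torch.randint(0, 5, (batch, H, W))          # class indices 0..4

bce_loss_fn = nn.BCEWithLogitsLoss()
ce_loss_fn = nn.CrossEntropyLoss()

# Channel 0 carries the binary logit; channels 1 through 5 carry the
# logits for the five multi-class segmentation classes.
loss_binary = bce_loss_fn(output[:, 0:1], binary_target)
loss_multiclass = ce_loss_fn(output[:, 1:6], class_target)
```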
With two separate decoders, you would do essentially the same thing – use the output of the
single-channel decoder for binary segmentation and the output of the five-channel decoder
for multi-class segmentation.
You would sum the two losses together, probably with weights, into a total loss, and then
backpropagate the total loss. The gradients computed for, say, the single-channel decoder
will depend only on the BCEWithLogitsLoss, while the gradients computed for the encoder
will depend on both the BCEWithLogitsLoss and the CrossEntropyLoss, so the encoder
will be trained to produce an encoding that is useful for both segmentation tasks.
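Continuing the sketch above, the weighted combination and single backward pass might look
like this (the loss weights are hypothetical hyperparameters you would tune empirically):

```python
# Weighted sum of the two task losses into a single total loss.
w_binary, w_multiclass = 1.0, 1.0   # illustrative weights, tune as needed
total_loss = w_binary * loss_binary + w_multiclass * loss_multiclass

# One backward pass. In the dual-decoder architecture, the binary
# decoder's gradients come only from loss_binary, the multi-class
# decoder's only from loss_multiclass, and the shared encoder's
# gradients come from both.
total_loss.backward()
```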
Again, don’t be fixated on having two separate decoders. There may well be advantages to
having essentially all of the U-Net be shared between – and trained jointly on – both of the
segmentation tasks. You should definitely try the single-joint-decoder approach and only
switch over to the dual-decoder architecture if it is empirically seen to work better.
(You could argue that such a joint decoder “is doing more work” than one of the single-task
decoders. So you could argue that the joint-decoder architecture should have more internal
channels than either of the single-task decoders does individually. To make a fair comparison
between the two architectures, you would have to experiment with that.)
Best.
K. Frank
Thank you very much, Frank. I think I wasn’t clear enough about the kind of task I am doing.
Let’s say I have a dataset of MRI scans of the prostate gland. I have one ground-truth label mask for tumours (two labels) and another ground-truth label mask for the different divisions of the prostate gland (multiple labels). I want to perform these two segmentation tasks simultaneously using a U-Net with a shared encoder and dual decoders, where each decoder performs a specific segmentation task based on its corresponding ground-truth mask.
Based on everything that has been discussed, is it possible to do something like that?