Hi Michael!
Does this mean that:

- The large structure is – either in the real world and/or in your
  target data – exactly the union of the three substructures?

- The three substructures have exactly zero intersection with
  one another?
If so, would it suffice to simply segment the three substructures, with
the large structure being determined by post-processing that takes
the union of the substructures, rather than directly by the model itself?
You could argue that getting the large structure even more correct
is more important than having the large structure fully consistent
with the substructures, and that having the model segment the
large structure (in addition to the substructures) lets you do this.
But there is something to be said for at least trying the simpler
approach.
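To illustrate the simpler approach, here is a minimal sketch (the tensor values are hypothetical) of recovering the large structure in post-processing as the union of the three substructures, given a predicted class-index mask:

```python
import torch

# Hypothetical predicted class-index mask of shape [nBatch, width, height],
# with values 0 (background), 1, 2, 3 (the three substructures).
pred = torch.tensor([[[0, 1, 2],
                      [3, 0, 1],
                      [0, 0, 2]]])

# Large structure as the union of the substructures: any non-background pixel.
large_structure = pred > 0    # boolean mask, shape [nBatch, width, height]
```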
I’ll mention `CrossEntropyLoss` again, below, but for now: It’s perfectly
reasonable to add Dice loss to `CrossEntropyLoss`, but my general
advice would be to start with `CrossEntropyLoss` as the default, and
only add Dice loss if it gives you a clear improvement.
If you don’t need to explicitly predict the large structure (because it
can be adequately recovered as the union of the three substructures),
you will be performing a single-label, four-class (background,
substructure-1, substructure-2, substructure-3) segmentation problem.
For this you will want `CrossEntropyLoss` (adding Dice loss, if that
clearly helps).
Speaking in terms of a 2d, grayscale image, the input to your model
should have shape `[nBatch, width, height]`, and the output (which will
become the input to `CrossEntropyLoss`) should have shape
`[nBatch, nClass = 4, width, height]`. From the masks you have, you
should build a single mask (per image) that consists of integer class
labels (with values 0 through 3) and has shape `[nBatch, width, height]`.
(Note that there is no `nClass` dimension.) The integer-class-label mask
will be the target for `CrossEntropyLoss`.
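As a concrete sketch of these shapes (with made-up sizes and random tensors standing in for your model output and mask):

```python
import torch
import torch.nn as nn

nBatch, nClass, width, height = 2, 4, 8, 8

# Model output (logits): shape [nBatch, nClass = 4, width, height].
logits = torch.randn(nBatch, nClass, width, height)

# Integer-class-label target mask: shape [nBatch, width, height],
# values 0 through 3.  (Note: no nClass dimension.)
target = torch.randint(0, nClass, (nBatch, width, height))

loss = nn.CrossEntropyLoss()(logits, target)
```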
If you want to also explicitly predict your large structure, either because
it is different from the union of the substructures, or because predicting
it separately is (usefully) more accurate than taking the union, then I
would recommend performing the two predictions in a single network
as follows:
Understand the large-structure prediction as a single-label, binary
segmentation problem, and the substructure predictions as a
single-label, four-class segmentation problem (as described above).
Make your last `Linear` layer have five outputs, so that the model
output has shape `[nBatch, 5, width, height]`. Understand the first of
your five outputs to be your large-structure prediction, and the
remaining four to be your substructure predictions. Make a binary
(0-1) mask for your large structure, and make a second four-class
(as above) mask for your substructures. Peel off the first of your five
outputs and feed it into `BCEWithLogitsLoss` together with your
large-structure mask, and peel off the last four outputs, feeding them
into `CrossEntropyLoss` together with the class-label mask. Add the
two losses together (probably with some relative weight that you tune
by hand) to get a combined loss (that you call `loss.backward()` on).
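The combined-loss recipe above can be sketched as follows (a random tensor stands in for the model output, and the relative weight of 1.0 is a placeholder you would tune):

```python
import torch
import torch.nn as nn

nBatch, width, height = 2, 8, 8

# Model output with five channels: shape [nBatch, 5, width, height].
output = torch.randn(nBatch, 5, width, height, requires_grad=True)

# Binary (0-1) mask for the large structure (float, for BCEWithLogitsLoss).
large_mask = torch.randint(0, 2, (nBatch, width, height)).float()

# Four-class integer-label mask for the substructures.
sub_mask = torch.randint(0, 4, (nBatch, width, height))

# Peel off channel 0 for the large structure, channels 1-4 for substructures.
loss_large = nn.BCEWithLogitsLoss()(output[:, 0], large_mask)
loss_sub = nn.CrossEntropyLoss()(output[:, 1:], sub_mask)

# Relative weight (here 1.0) that you would tune by hand.
weight = 1.0
loss = loss_large + weight * loss_sub
loss.backward()
```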
Your large structure and substructures are closely related, so it
makes sense to train a single model jointly on these two segmentation
tasks. Most of the upstream processing and features are shared by
the two tasks, and only the final `Linear` layer “learns” how to predict
the large structure and the substructures separately from the upstream
features.
Based on my assumptions about your use case, this approach seems
to do what you most naturally want. You have both a single-label,
binary segmentation task, and a single-label, four-class segmentation
task, and this approach uses the (generally) most applicable methods
to train these two tasks.
Lastly, to clarify:
You would typically use `BCEWithLogitsLoss` for a multi-label,
multi-class problem, in which case your target image (mask) should
be, if you will, multi-hot encoded. That is, there can be a `1` in the
position for any number of your classes, including none and all of
them. (In this case, you would not have a background class – the
background “class” would be indicated by all `0`’s, that is, `0`’s for all
of your foreground classes.)
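A small sketch of what such a multi-hot target looks like (with made-up sizes and three hypothetical foreground classes):

```python
import torch

# Multi-hot target for BCEWithLogitsLoss: shape
# [nBatch, nClass = 3, width, height], each channel independently 0 or 1.
target = torch.zeros(1, 3, 2, 2)

target[0, 0, 0, 0] = 1.0    # pixel (0, 0) belongs to class 0 ...
target[0, 1, 0, 0] = 1.0    # ... and also to class 1 (multi-label)

# Pixel (1, 1) is left at all 0's -- that is "background": no explicit
# background class, just 0's for all foreground classes.
```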
If you can, in fact, one-hot encode your target – exactly one of your
classes (including the background class) is active at a time – then
you have a single-label, multi-class problem, and, in general,
`CrossEntropyLoss` will be the better choice (because, in essence,
your network “knows” that it’s being trained for a single-label task).
Good luck.
K. Frank