How to create a tensor with a channel for classes

Hello,
I have a problem. I have built a CNN that gives me an output of shape (BATCH_SIZE, 1, HEIGHT, WIDTH). However, to apply the CrossEntropyLoss function, I need an output of shape (BATCH_SIZE, NB_CLASS, HEIGHT, WIDTH). How can I get such a tensor from my original output? I currently have three classes and want to do image segmentation; my output images are, of course, float32.

Thank you for your answers :)

Hi William!

As an aside, it sounds like your current CNN performs binary (i.e.,
background-foreground) segmentation.

It sounds like your desired use case is multi-class (with NB_CLASS
classes) semantic segmentation.

First, once you’re down to (BATCH_SIZE, 1, HEIGHT, WIDTH), that
is, a single output channel, you will no longer be able to recover your
desired NB_CLASS channels. You will have to modify the final bit
of your CNN architecture, and the details will depend on your specific
architecture.

It is likely that your next-to-last layer will produce an “image” with
a number of “feature” channels, and your last layer will recombine
your feature channels into the desired number of “class” channels.
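As a minimal sketch of this two-stage ending, here is a hypothetical toy network (the name TinySegNet, the single-channel grayscale input, and the two 3x3 "body" convolutions are all assumptions, not your CNN). The body produces 64 feature channels per pixel, and the last layer recombines them into NB_CLASS class channels while keeping HEIGHT and WIDTH unchanged:

```python
import torch
import torch.nn as nn

NB_CLASS = 3      # three classes, as in the question
NB_FEATURES = 64  # hypothetical number of "feature" channels


class TinySegNet(nn.Module):
    """Hypothetical toy segmentation net: feature body + class head."""

    def __init__(self, in_channels: int = 1,
                 nb_features: int = NB_FEATURES, nb_class: int = NB_CLASS):
        super().__init__()
        # Placeholder body; kernel_size=3 with padding=1 keeps HEIGHT and WIDTH.
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, nb_features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(nb_features, nb_features, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Final layer: recombine feature channels into class channels.
        self.head = nn.Conv2d(nb_features, nb_class, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.body(x)     # (BATCH_SIZE, nb_features, HEIGHT, WIDTH)
        return self.head(features)  # (BATCH_SIZE, nb_class, HEIGHT, WIDTH)


net = TinySegNet()
out = net(torch.randn(2, 1, 32, 48))  # (BATCH_SIZE=2, 1, HEIGHT=32, WIDTH=48)
print(out.shape)
```

Only the head decides how many output channels you get; the body can be whatever your existing architecture already uses.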

For example, quoting from the original U-Net paper, “At the final layer
a 1x1 convolution is used to map each 64-component feature vector
to the desired number of classes.”
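To make "map each 64-component feature vector" concrete, here is a small check (a sketch with random stand-in features, not code from the paper) that a 1x1 convolution is just the same linear map, a (NB_CLASS x 64) weight matrix plus a bias, applied independently at every pixel:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
NB_CLASS = 3
conv = nn.Conv2d(in_channels=64, out_channels=NB_CLASS, kernel_size=1)
features = torch.randn(2, 64, 8, 8)  # random stand-in: (BATCH_SIZE, 64, HEIGHT, WIDTH)

with torch.no_grad():
    out = conv(features)  # (2, NB_CLASS, 8, 8)

    # Same computation by hand: at each pixel (b, h, w), multiply the
    # 64-component feature vector by a (NB_CLASS, 64) matrix, then add the bias.
    weight = conv.weight.view(NB_CLASS, 64)  # (3, 64, 1, 1) -> (3, 64)
    manual = (torch.einsum("kc,bchw->bkhw", weight, features)
              + conv.bias.view(1, NB_CLASS, 1, 1))

print(torch.allclose(out, manual, atol=1e-5))
```

Because the kernel covers a single pixel, this layer mixes channels but never looks at neighboring pixels; the spatial context has to come from the earlier layers.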

In the context of PyTorch, your current final layer might look something
like this:

torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)

(This would be for the case that your next-to-last layer produces 64
“feature” channels.)

In such a case you would change out_channels to NB_CLASS, that
is, replace your final layer with:

torch.nn.Conv2d(in_channels=64, out_channels=NB_CLASS, kernel_size=1)
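A minimal sketch of how the NB_CLASS-channel output then feeds CrossEntropyLoss, assuming random stand-in features and a random target (both hypothetical). The loss takes the raw, unnormalized scores (no softmax beforehand), and the target has no class channel: it is an integer (int64) tensor of shape (BATCH_SIZE, HEIGHT, WIDTH) holding one class index in [0, NB_CLASS) per pixel:

```python
import torch
import torch.nn as nn

BATCH_SIZE, NB_CLASS, HEIGHT, WIDTH = 4, 3, 16, 16

final_layer = nn.Conv2d(in_channels=64, out_channels=NB_CLASS, kernel_size=1)

# Random stand-in for the output of your next-to-last layer.
features = torch.randn(BATCH_SIZE, 64, HEIGHT, WIDTH)
logits = final_layer(features)  # (BATCH_SIZE, NB_CLASS, HEIGHT, WIDTH), raw scores

# Hypothetical ground truth: one class index per pixel, dtype int64.
target = torch.randint(0, NB_CLASS, (BATCH_SIZE, HEIGHT, WIDTH))

loss = nn.CrossEntropyLoss()(logits, target)
loss.backward()

# Predicted segmentation map: the highest-scoring class at each pixel.
pred = logits.argmax(dim=1)  # (BATCH_SIZE, HEIGHT, WIDTH), values 0..NB_CLASS-1
print(logits.shape, pred.shape, loss.item())
```

Note that pred holds class indices (0, 1, 2 here), not pixel intensities.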

This illustrates the basic idea for classic U-Net – your specific CNN
may differ in detail.

Best.

K. Frank


Thank you for your answer. It was very clear. I couldn’t get my segmentation to work (it only gives me black images as output), but my CrossEntropyLoss now works. Thanks a lot!