How to apply an element-wise cross_entropy/nll loss to a 3D matrix?

Hi guys, I’m working on a 3D U-Net model. Currently I’m having trouble implementing the loss function for a 3D tensor. The output of my model is 1 x 2 x 64 x 64 x 64 (batch size, class, H, W, layer) and the ground truth is
1 x 64 x 64 x 64. However, the nll_loss function does not let me pass in a 5D tensor, which seems to contradict the documentation. Could someone please help me with this problem?

What version of pytorch are you using?

You might be looking at the master version of the docs: http://pytorch.org/docs/master/nn.html?highlight=nllloss#torch.nn.NLLLoss. NLLLoss for multiple dimensions was recently added; you’ll have to build from source to use it.
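For illustration, this is roughly how the multi-dimensional NLLLoss behaves once you’re on a build that includes it (the tensors below are made-up stand-ins, not from the original post):

import torch
import torch.nn as nn
import torch.nn.functional as F

# stand-in for a network output of shape (N, C, d1, d2, d3), normalised to log-probabilities
log_probs = F.log_softmax(torch.randn(1, 2, 64, 64, 64), dim=1)
# matching target of class indices with shape (N, d1, d2, d3)
target = torch.zeros(1, 64, 64, 64).long()

criterion = nn.NLLLoss()
loss = criterion(log_probs, target)   # accepted directly, no reshaping needed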

I see, my current version is 0.3.0 from conda. Could you please suggest a way to solve this problem in 0.3.0?

In 0.3.0, you’d have to use the following workaround. You can reshape your input into a 4D tensor and your target into a 3D tensor (collapsing the last two spatial dimensions), then pass them to the loss function, like so:

import torch.nn as nn

input_4d = input.view(1, 2, 64, -1)    # (N, C, H, W*L): a 4D shape NLLLoss accepts in 0.3.0
target_3d = target.view(1, 64, -1)     # (N, H, W*L): matching target of class indices
loss = nn.NLLLoss(reduce=False)        # keep the per-element losses
out_3d = loss(input_4d, target_3d)
out = out_3d.view(1, 64, 64, 64)       # reshape the per-voxel loss back to the original grid
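One caveat, in case it isn’t obvious from the snippet: NLLLoss expects log-probabilities, so if the network doesn’t already end in a log_softmax you would apply it over the class dimension before the reshape. A minimal sketch, assuming a raw (not yet normalised) output:

import torch
import torch.nn.functional as F

model_output = torch.randn(1, 2, 64, 64, 64)     # stand-in for the raw network output
log_probs = F.log_softmax(model_output, dim=1)   # normalise over the class dimension (size 2)
input_4d = log_probs.view(1, 2, 64, -1)          # then reshape exactly as in the workaround above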

This is really helpful. Thanks a lot.

The next version of pytorch will have this behavior built-in, as you’ve seen in the docs!


@Ryanzsun can you tell me how you converted the 3D images into channels, as described in the paper,
i.e. X x Y x Z x 3? How can medical data with n slices have 3 channels? Did you convert each slice to RGB?

Actually I didn’t convert the images into RGB channels; most medical images contain only one channel. This question was about the output format of the network.

@Ryanzsun yes, exactly. Actually I am trying to implement 3D U-Net, but the model takes an input of shape (X, Y, Z, 3), and the paper says there are 3 channels, so what are these 3 channels then?

My data has only one channel, which is the HU value of the MRI image. Some data, for example color ultrasonography, may contain RGB channels. It totally depends on your dataset: you should change your model according to the data, you don’t have to stick to the paper.
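If it helps, switching the network to single-channel input usually only means changing the first convolution; a minimal sketch (the layer and its hyperparameters are illustrative, not taken from the paper):

import torch.nn as nn

# first encoder layer of a 3D U-Net style model; in_channels matches the data
first_conv = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
# for RGB-like data (e.g. color Doppler ultrasound) you would use in_channels=3 instead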

What about cross entropy loss in pytorch 1.1.0? My input is 3D: [#batch_size, #class_index, #scores], while the target is 2D: [#batch_size, #class_index]. I’ve asked the question here, but couldn’t find a good solution yet: Cross entropy loss for sentence classification

The answer has been provided here.
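For reference, a minimal sketch of how nn.CrossEntropyLoss handles K-dimensional input in recent versions; the shapes mirror the question and the names below are made up:

import torch
import torch.nn as nn

batch_size, num_classes, seq_len = 4, 5, 10
logits = torch.randn(batch_size, num_classes, seq_len)       # (N, C, d1) raw scores
target = torch.randint(num_classes, (batch_size, seq_len))   # (N, d1) class indices

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)   # applies log_softmax over dim 1 and NLLLoss internally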