CrossEntropy in 2D softmax output

Hi there,

I’m using a softmax as the last layer of my NN to output a 2D normalised heatmap (a probability distribution over which pixel in the image is the correct one). I had to implement this layer myself:

import torch
import torch.nn.functional as F

class Softmax2DProbability(torch.nn.Module):
    def __init__(self):
        super(Softmax2DProbability, self).__init__()

    def forward(self, x):
        orig_shape = x.shape  # (N, 1, H, W); the flattening below assumes a single channel
        return F.softmax(x.view((orig_shape[0], orig_shape[2]*orig_shape[3])), dim=1).view(orig_shape)

Assuming this layer is correct (which it seems to be), how do I compute the cross entropy between the output of this layer and the target heatmap (a one-hot 2D array with only the correct pixel set)?

If I reshape my tensor to use Torch’s current CrossEntropy, will autograd know automatically what to do to differentiate?

Thank you,

Unpacked code for better understanding:

class Softmax2DProbability(torch.nn.Module):
    def __init__(self):
        super(Softmax2DProbability, self).__init__()

    def forward(self, x):
        x_orig_shape = x.shape
        x_vectorised_image_size = x_orig_shape[2]*x_orig_shape[3]
        x_reshaped = x.view((x_orig_shape[0], x_vectorised_image_size))
        # now I can use torch's softmax over all pixels of each image
        soft_max_vectorised_x = F.softmax(x_reshaped, dim=1)
        # reshape back to the input shape and return
        return soft_max_vectorised_x.view(x_orig_shape)

Yep. Basically, if you do an operation on a Variable and PyTorch complains neither during the operation nor during .backward(), it is a fairly safe bet that autograd was able to differentiate it properly.
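To make the reshape approach concrete, here is a minimal sketch, assuming a single-channel (N, 1, H, W) heatmap and a target given as one (row, col) pixel per image. The helper `heatmap_cross_entropy` is hypothetical. It flattens the logits to (N, H*W), turns the pixel position into a row-major class index, and lets `F.cross_entropy` apply the log-softmax internally, so it takes the raw scores rather than the output of `Softmax2DProbability`. The block also checks the value against the hand-computed `-log p` of the target pixel and runs `torch.autograd.gradcheck` to confirm that autograd differentiates through the reshapes.

```python
import torch
import torch.nn.functional as F

def heatmap_cross_entropy(logits, target_rows, target_cols):
    """Cross entropy between a per-image 2D softmax and a one-hot target pixel.

    logits: (N, 1, H, W) raw scores (before any softmax).
    target_rows, target_cols: (N,) long tensors with the correct pixel position.
    """
    n, c, h, w = logits.shape
    assert c == 1, "this sketch assumes a single heatmap channel"
    flat_logits = logits.reshape(n, h * w)        # (N, H*W): one "class" per pixel
    flat_target = target_rows * w + target_cols   # row-major pixel index in [0, H*W)
    # cross_entropy applies log_softmax itself, so pass logits, not probabilities
    return F.cross_entropy(flat_logits, flat_target)

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 5, dtype=torch.double, requires_grad=True)
rows = torch.tensor([1, 3])
cols = torch.tensor([2, 0])
loss = heatmap_cross_entropy(logits, rows, cols)

# Same value computed by hand from the normalized heatmap
probs = F.softmax(logits.reshape(2, -1), dim=1).reshape(2, 1, 4, 5)
manual = -torch.log(probs[torch.arange(2), 0, rows, cols]).mean()
print(torch.allclose(loss, manual))  # True

# gradcheck compares autograd's gradient with finite differences
print(torch.autograd.gradcheck(lambda t: heatmap_cross_entropy(t, rows, cols), (logits,)))
```

Working on logits avoids taking the log of a softmax output that may have underflowed to zero; the softmax layer can still be applied separately whenever the normalized heatmap itself is needed.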


I’m writing this reply for any human in the distant future, in some not-so-distant galaxy, who needs something similar.
The module I ended up using is generic for any number of channels. It takes a (N_BATCHES, N_CHANNELS, WIDTH, DEPTH) tensor holding a batch of N_BATCHES images with N_CHANNELS channels each. It outputs another (N_BATCHES, N_CHANNELS, WIDTH, DEPTH) tensor containing the log probabilities for each batch element and each channel (computed and normalized across WIDTH and DEPTH).

import torch
import torch.nn.functional as F

class SoftmaxLogProbability2D(torch.nn.Module):
    def __init__(self):
        super(SoftmaxLogProbability2D, self).__init__()

    def forward(self, x):
        orig_shape = x.shape
        seq_x = []
        for channel_ix in range(orig_shape[1]):
            # flatten each channel to (N, H*W) and normalize over all of its pixels;
            # log_softmax is numerically safer than softmax(...).log()
            log_softmax_ = F.log_softmax(x[:, channel_ix, :, :].contiguous()
                                         .view((orig_shape[0], orig_shape[2] * orig_shape[3])), dim=1)\
                .view((orig_shape[0], orig_shape[2], orig_shape[3]))
            seq_x.append(log_softmax_)
        x = torch.stack(seq_x, dim=1)
        return x

It seems to work (so far).
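The per-channel loop can also be written without a loop. The sketch below (the class name `SoftmaxLogProbability2DVectorized` is hypothetical) flattens only the two spatial dimensions, so every (sample, channel) pair becomes one row of H*W scores, and applies `F.log_softmax` along that flattened axis. Under the assumption of an (N, C, H, W) input it should produce the same values as the loop version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftmaxLogProbability2DVectorized(nn.Module):
    """Per-channel log-softmax over the spatial dimensions of an (N, C, H, W) tensor."""

    def forward(self, x):
        n, c, h, w = x.shape
        # Flatten H and W so each (sample, channel) pair is one row of H*W scores
        flat = x.reshape(n, c, h * w)
        # Normalize over the flattened spatial axis, then restore the original shape
        return F.log_softmax(flat, dim=2).reshape(n, c, h, w)

torch.manual_seed(0)
x = torch.randn(3, 2, 4, 5)
out = SoftmaxLogProbability2DVectorized()(x)
# Each channel's probabilities sum to 1 over its spatial positions
print(out.exp().sum(dim=(2, 3)))
```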


Hello, I am still confused about how the cross-entropy loss in PyTorch works with 1D data. Here is my situation:
- I have 3 classes.
- Input = (N, C, W), output = (N, W) --> Input (64, 3, 640), output (64, 640).
- I tried nn.CrossEntropyLoss, but something was wrong with the dimensions, so I unsqueezed the tensors and treated them as images.
- After unsqueezing: Input = (N, C, H, W), output = (N, H, W) --> Input (64, 3, 1, 640), output (64, 1, 640).
- I followed this implementation:
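import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True, ignore_index=255):
        super(CrossEntropyLoss2d, self).__init__()
        # nn.NLLLoss handles (N, C, H, W) inputs; NLLLoss2d and size_average are deprecated
        self.nll_loss = nn.NLLLoss(weight=weight, ignore_index=ignore_index,
                                   reduction='mean' if size_average else 'sum')

    def forward(self, inputs, targets):
        # log-softmax over the class dimension (dim=1)
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)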

However, I got this error:
RuntimeError: invalid argument 3: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4 at /opt/conda/conda-bld/pytorch-nightly_1538995270066/work/aten/src/THNN/generic/SpatialClassNLLCriterion.c:59

I’m not sure what Input and output are.
Based on the shapes it seems that Input is actually your model prediction, while output seems to be the target. Is this correct?
Here is a small example of using nn.CrossEntropyLoss for images, e.g. in a segmentation use case:

import torch
import torch.nn as nn

batch_size = 1
c, h, w = 1, 10, 10
nb_classes = 3
x = torch.randn(batch_size, c, h, w)
target = torch.empty(batch_size, h, w, dtype=torch.long).random_(nb_classes)

model = nn.Conv2d(c, nb_classes, 3, 1, 1)
criterion = nn.CrossEntropyLoss()

output = model(x)
loss = criterion(output, target)
loss.backward()

Let me know if your use case differs.


Hello @ptrblck, thank you for the fast response. I have read the documentation again and again, and I finally chose an input of size NCW and a target of size NW, i.e. (64, 3, 640) and (64, 640), as stated here, which uses log softmax followed by NLLLoss. However, I got this error:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /opt/conda/conda-bld/pytorch-nightly_1538995270066/work/aten/src/THNN/generic/SpatialClassNLLCriterion.c:110

Finally I realized it was because the ground-truth values must be normalized, since probabilities range from 0 to 1 :rofl:. Thank you so much for your help.

Good to hear it’s working!
However, let’s check the last error message as I have the feeling there might be a silent bug.
The RuntimeError states that your target contains invalid values, i.e. it should contain values in the range [0, n_classes-1].
Basically your target tensor contains the class indices for all samples, not probabilities!

How did you normalize the target to get the probabilities?

Based on the output shape of your model, it looks like you are dealing with three classes.
The target tensor should thus contain long values in [0, 2].
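For the 1D case from this thread, a minimal sketch (random data standing in for the real model output) shows that `nn.CrossEntropyLoss` accepts an (N, C, W) input and an (N, W) target directly, with no unsqueezing. The target holds long class indices in [0, 2], not probabilities. The second half puts a single out-of-range index into the target; on CPU this is expected to raise an error similar to the assertion above, which is why remapping the labels to valid indices, rather than normalizing them, is the fix.

```python
import torch
import torch.nn as nn

n_classes = 3
batch, length = 64, 640
torch.manual_seed(0)

# Model output: raw logits of shape (N, C, W), no softmax applied
logits = torch.randn(batch, n_classes, length, requires_grad=True)
# Target: one class index per position, shape (N, W), dtype long, values in [0, n_classes - 1]
target = torch.randint(0, n_classes, (batch, length))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)
loss.backward()
print(loss.shape, logits.grad.shape)

# A single invalid class index (here 3) makes the loss raise instead of returning a value
bad_target = target.clone()
bad_target[0, 0] = n_classes
raised = False
try:
    criterion(logits, bad_target)
except (IndexError, RuntimeError) as err:
    raised = True
    print("invalid target:", err)
```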

Sorry, I tried to reproduce it but failed. However, I had the same situation yesterday in a Jupyter notebook, so I can still see the log, though with different sizes. I think the problem there was that my target had the wrong dimensions. Here is the code.


The shapes of x and m look alright if you have two classes.
Did you fail to reproduce the error? If that’s the case, it’s alright. :wink:
I just wanted to make sure I’m not misunderstanding your normalization of the target.


Yes all is well now, thank you @ptrblck :blush:


Basically your target tensor contains the class indices for all samples, not probabilities!

Hi, is there any loss function in PyTorch similar to CrossEntropyLoss that takes the ground-truth probabilities as the target instead of class indices? If not, any suggestions on how to implement it? Thanks!

In this thread a cross entropy loss was implemented using continuous (or soft) targets.
Would that work for you?
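As an illustration of a soft-target cross entropy (not necessarily the implementation from the linked thread), the hypothetical helper `soft_cross_entropy` below computes `-(p * log_softmax(logits)).sum(dim=1)` and averages over the batch and any spatial dimensions. Newer PyTorch releases (1.10 and later) also accept class probabilities as the target of `F.cross_entropy` / `nn.CrossEntropyLoss`, as long as the target has the same shape as the input, so the sketch compares both.

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, target_probs):
    """Cross entropy with probability targets; the class dimension is dim=1."""
    log_probs = F.log_softmax(logits, dim=1)
    # Sum over classes, then average over the batch (and any spatial dims)
    return -(target_probs * log_probs).sum(dim=1).mean()

torch.manual_seed(0)
logits = torch.randn(4, 3, 5, 5)
target_probs = F.softmax(torch.randn(4, 3, 5, 5), dim=1)  # sums to 1 over classes

loss = soft_cross_entropy(logits, target_probs)
# PyTorch 1.10+: cross_entropy also accepts a probability target of the input's shape
builtin = F.cross_entropy(logits, target_probs)
print(torch.allclose(loss, builtin))  # True
```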

Hello @ptrblck, sorry, this may be off-topic, but if we implement a loss function in the new PyTorch 1.0 or 0.4, is it necessary to define a backward() function?
Suppose I want to implement a dice coefficient loss like this:

import torch

def dice_loss(pred, target):
    smooth = 1.
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    A_sum = torch.sum(iflat * iflat)  # sum of squared predictions
    B_sum = torch.sum(tflat * tflat)  # sum of squared targets
    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))

It runs without any errors, but I don't understand whether it really backpropagates and computes gradients. How can we check this?
I work on segmentation, so I need to sum the losses, e.g. loss = cross_entropy + dice. Does that really backpropagate through both, or does it only compute the cross-entropy gradient and not the dice one, since dice_loss does not implement a backward() function?

You only need to implement the backward function yourself if you use non-PyTorch operations (e.g. NumPy) or if you would like to speed up the backward pass and think you have a more performant backward implementation than the one autograd derives from pure PyTorch operations.
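For the NumPy case, a minimal sketch of a custom `torch.autograd.Function` looks like this. `NumpyExp` is a hypothetical example, and `exp` merely stands in for any NumPy operation that autograd cannot trace. The forward pass leaves PyTorch, so the backward pass has to supply the derivative by hand; `torch.autograd.gradcheck` then compares it with finite differences.

```python
import numpy as np
import torch

class NumpyExp(torch.autograd.Function):
    """exp() computed in NumPy; autograd cannot see inside, so backward is manual."""

    @staticmethod
    def forward(ctx, x):
        # Leave PyTorch: compute on a NumPy copy, then convert back to the input's device
        result = torch.from_numpy(np.exp(x.detach().cpu().numpy())).to(x.device)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # d/dx exp(x) = exp(x), so scale the incoming gradient by the saved output
        return grad_output * result

torch.manual_seed(0)
x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(NumpyExp.apply, (x,)))  # True
```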

Basically, if you just use PyTorch operations, you don’t need to define backward as Autograd is able to track all operations and create the backward pass.
Your method looks fine! If you want to check for gradients, just call dice_coeff.backward() and check some layers for gradients. Something like this should work:

...
dice_coeff = dice_loss(pred, target)
dice_coeff.backward()
print(model.fc.weight.grad)

If that gives you valid gradients, you could check the range of both losses and scale one if necessary.
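A self-contained version of this check is sketched below. The single `nn.Conv2d` layer is a hypothetical stand-in for a segmentation network, and the sigmoid turns its output into foreground probabilities. `dice_loss` uses only PyTorch operations, so after `loss.backward()` the layer's weight gradient should be populated and non-zero without any hand-written backward.

```python
import torch
import torch.nn as nn

def dice_loss(pred, target, smooth=1.0):
    iflat = pred.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - (2.0 * intersection + smooth) / (iflat.pow(2).sum() + tflat.pow(2).sum() + smooth)

torch.manual_seed(0)
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # hypothetical stand-in for a segmentation net
x = torch.randn(2, 1, 8, 8)
target = (torch.rand(2, 1, 8, 8) > 0.5).float()    # binary ground-truth mask

pred = torch.sigmoid(model(x))  # foreground probabilities in (0, 1)
loss = dice_loss(pred, target)
loss.backward()
# A non-zero gradient shows that autograd backpropagated through dice_loss
print(model.weight.grad.abs().sum() > 0)
```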


I see, thank you so much for the explanation. What about a combination, like this:

criterion = nn.CrossEntropyLoss(weight=weight_value)
loss = criterion(outputs, labels)
dice = dice_loss(preds, labels.data)
loss = loss+dice

So when I call loss.backward(), will it compute the gradients of both losses, or do I need to call backward twice, once for each loss?

You can just call loss.backward() on the summed loss.
However, is there a reason you use labels.data for the dice loss instead of directly the tensor?
It probably won’t cause any errors, but using .data is generally not recommended.
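To see that one backward() call on the summed loss covers both terms, the sketch below compares the weight gradient of `ce + dice` with the sum of the gradients from two separate backward passes. It assumes a hypothetical two-class toy model and uses the softmax probability of class 1 as the dice prediction against the binary labels, passing `labels` directly instead of `labels.data`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def dice_loss(pred, target, smooth=1.0):
    iflat, tflat = pred.reshape(-1), target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - (2.0 * intersection + smooth) / (iflat.pow(2).sum() + tflat.pow(2).sum() + smooth)

torch.manual_seed(0)
model = nn.Conv2d(1, 2, kernel_size=3, padding=1)  # hypothetical 2-class toy segmentation model
x = torch.randn(2, 1, 8, 8)
labels = torch.randint(0, 2, (2, 8, 8))
criterion = nn.CrossEntropyLoss()

def ce():
    return criterion(model(x), labels)

def dice():
    # foreground (class 1) probability against the binary labels, no .data needed
    return dice_loss(F.softmax(model(x), dim=1)[:, 1], labels.float())

def weight_grad(loss_fn):
    model.zero_grad()
    loss_fn().backward()
    return model.weight.grad.clone()

g_ce, g_dice = weight_grad(ce), weight_grad(dice)
g_sum = weight_grad(lambda: ce() + dice())
# The single backward on the sum equals the two separate gradients added together
print(torch.allclose(g_sum, g_ce + g_dice, atol=1e-6))  # True
```

If the two losses live on very different scales, a weighted sum such as `ce() + 0.5 * dice()` is backpropagated in exactly the same way.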


No problem @ptrblck, I still remember the old PyTorch syntax and need to adapt to version 1.0 now. Still learning. Thank you so much for the explanation :smiley: .


Hello Paulo,

Should the number of classes be equal to the number of channels, as given here in your code snippet (NLLLoss2d)?

I have a similar problem where each pixel at the end of the convolution layers should have a label between 0 and 9. I wonder, shouldn't NLLLoss2d work for you as well?
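For per-pixel labels 0 to 9, a common setup is a final layer with one output channel per class. The sketch below assumes a hypothetical 16-channel feature map and a 1x1 convolution head. In current PyTorch, `nn.CrossEntropyLoss` and `nn.NLLLoss` both accept an (N, C, H, W) input with an (N, H, W) long target directly, so the separate NLLLoss2d class is not needed; the block also checks that cross entropy equals NLLLoss on `log_softmax` over the class dimension.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_classes = 10  # per-pixel labels 0-9
torch.manual_seed(0)
features = torch.randn(4, 16, 32, 32)              # hypothetical (N, 16, H, W) feature map
head = nn.Conv2d(16, n_classes, kernel_size=1)     # one output channel per class
target = torch.randint(0, n_classes, (4, 32, 32))  # (N, H, W) long class indices

logits = head(features)                            # (N, n_classes, H, W)
loss = nn.CrossEntropyLoss()(logits, target)
# Equivalent formulation: NLLLoss on log-probabilities over the class dimension
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
loss.backward()
print(logits.shape, torch.allclose(loss, nll))
```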

Regards
Surojit