Segmentation Network Loss issues

Hello,

I’ve read quite a few relevant topics here on discuss.pytorch.org such as:

I’ve tried nn.CrossEntropyLoss, but it comes with problems I don’t know how to easily overcome.
So I’m now trying nn.NLLLoss (PyTorch 1.3) after running the network logits through torch.nn.functional.log_softmax.

The dataset is organised so that each input has a 7-channel tensor, where each channel is a class and the spatial dims (224x224) are the pixels for that class.
Because an image can contain multiple classes, I want them all indexed across their respective channels.

I’ve followed the instructions from the above topics, but I keep getting the same error below:

logits: torch.Size([32, 7, 224, 224])
target: torch.Size([32, 7, 224, 224])
Traceback (most recent call last):
  File "train_segnet.py", line 183, in <module>
    loss   = criterion(logits, labels.cuda().long())
  File "venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "venv/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 204, in forward
    return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
  File "venv/lib/python3.6/site-packages/torch/nn/functional.py", line 1840, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [32, 7, 224, 224]

I thought nn.NLLLoss was capable of dealing with multi-dimensional tensors?
Should I squeeze them all across a single axis? How do I solve this issue so that the network can learn to produce those tensors as output?

Based on your description I assume you are working on a multi-class segmentation use case, i.e. each pixel belongs to one of the 7 classes.
In that case, the target should not contain the channel dimension (number of classes), but instead have the shape [batch_size, height, width] and contain class indices in the range [0, nb_classes - 1].

Here is a small dummy code using your shapes:

import torch
import torch.nn as nn

nb_classes = 7
batch_size, height, width = 32, 224, 224

# logits keep the class dimension, the target only stores a class index per pixel
output = torch.randn(batch_size, nb_classes, height, width, requires_grad=True)
target = torch.randint(0, nb_classes, (batch_size, height, width))

criterion = nn.CrossEntropyLoss()

loss = criterion(output, target)
loss.backward()

What’s the advantage of doing this? This is the way I originally tried to implement it, but I couldn’t work out the accuracy or loss implementation. I’ll try doing it this way and see what happens :slight_smile:

Also, would you recommend CrossEntropyLoss, NLLLoss, or some other function for this sort of thing, and why one over the other?

So I tried what was suggested, and I’m still quite confused, since now I am getting the following error:

logit shape torch.Size([32, 224, 224])
label shape torch.Size([32, 224, 224])
Traceback (most recent call last):
  File "train_segnet.py", line 185, in <module>
    loss   = criterion(logits, labels.cuda().long())
  File "venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "venv/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 916, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "venv/lib/python3.6/site-packages/torch/nn/functional.py", line 2009, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "venv/lib/python3.6/site-packages/torch/nn/functional.py", line 1848, in nll_loss
    out_size, target.size()))
ValueError: Expected target size (32, 224), got torch.Size([32, 224, 224])

However:

  1. the network seems to output values in a different range (I am using DeepLabV3('resnet18', classes=1) from this repository: https://github.com/qubvel/segmentation_models.pytorch)
  2. that was the reason I had everything placed in different channels (the repo uses that approach)
  3. I’ve arranged my data as you suggested, but that doesn’t seem to work for me right now.

One of the errors I am currently getting is:

/pytorch/aten/src/THCUNN/SpatialClassNLLCriterion.cu:104: void cunn_SpatialClassNLLCriterion_updateOutput_kernel(T *, T *, T *, long *, T *, int, int, int, int, int, long) [with T = float, AccumT = float]: block: [7,0,0], thread: [991,0,0] Assertion `t >= 0 && t < n_classes` failed

Which makes sense because when I look at the network output/logits they are:

[ 0.3511,  0.3759,  0.4006,  ...,  0.2662,  0.2667,  0.2672],
          [ 0.3513,  0.3746,  0.3980,  ...,  0.3169,  0.3135,  0.3100],
          [ 0.3514,  0.3734,  0.3954,  ...,  0.3677,  0.3602,  0.3528],
          ...,
          [ 0.2786,  0.2707,  0.2628,  ...,  0.0179, -0.0218, -0.0616],
          [ 0.2405,  0.2318,  0.2231,  ..., -0.0140, -0.0691, -0.1243],
          [ 0.2024,  0.1928,  0.1833,  ..., -0.0458, -0.1164, -0.1870]]],


        [[[-0.1596, -0.0971, -0.0346,  ...,  0.2567,  0.2747,  0.2928],
          [-0.0786, -0.0282,  0.0221,  ...,  0.2614,  0.2682,  0.2750],
          [ 0.0025,  0.0406,  0.0787,  ...,  0.2662,  0.2617,  0.2572],
          ...,
          [ 0.2729,  0.2585,  0.2440,  ..., -0.1774, -0.2158, -0.2543],
          [ 0.2877,  0.2605,  0.2333,  ..., -0.2256, -0.2790, -0.3325],
          [ 0.3024,  0.2625,  0.2225,  ..., -0.2739, -0.3423, -0.4107]]],


        [[[ 0.0162,  0.0238,  0.0313,  ...,  0.0944,  0.0774,  0.0605],
          [ 0.0553,  0.0645,  0.0736,  ...,  0.1529,  0.1334,  0.1138],
          [ 0.0943,  0.1052,  0.1160,  ...,  0.2115,  0.1893,  0.1671],
          ...,
          [-0.0165,  0.0217,  0.0599,  ..., -0.0660, -0.0712, -0.0764],
          [-0.0538, -0.0219,  0.0100,  ..., -0.0507, -0.0527, -0.0547],
          [-0.0912, -0.0655, -0.0399,  ..., -0.0355, -0.0342, -0.0329]]],


        ...,

So, is there some transformation I should apply to the logits?
Or should I go back to the 7-channel output and then translate it into a single-channel tensor that can be used to calculate the loss, in the shape you have suggested?
I can obviously go from one channel per class to a single channel with class index values, and then find a way to translate those so that the losses can be calculated, assuming that will work as per your example.

Your logit output shape is missing the class dimension.
In my code snippet I’m creating the logits as [batch_size, nb_classes, height, width] and the target as [batch_size, height, width]. If you stick to these shapes, it should work.

nn.CrossEntropyLoss expects logits and uses F.log_softmax + nn.NLLLoss internally, so these approaches will yield the same result.
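
As a quick sanity check, here is a minimal sketch (random tensors, not your actual model) showing that the two paths give the same loss value:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, nb_classes, height, width = 32, 7, 224, 224

logits = torch.randn(batch_size, nb_classes, height, width)
target = torch.randint(0, nb_classes, (batch_size, height, width))

# nn.CrossEntropyLoss applied directly to the raw logits ...
loss_ce = nn.CrossEntropyLoss()(logits, target)

# ... matches log_softmax over the class dimension followed by nn.NLLLoss
loss_nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

print(torch.allclose(loss_ce, loss_nll))  # True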


I see now why this has been so confusing. I’ve run your example and realised what’s going on.
The actual logits have the shape [batch, class_channel, height, width], which is what I’ve had all along. The labels/targets/ideals have the shape [batch, height, width]. That was the confusing part, and I didn’t really get why it’s like this until I saw what was happening with the logits:

  • if they didn’t have a separate channel per class, I suspect there would be no way to transform their raw/logit values into a class representation? Which is the problem I’ve run into twice so far.
    So I transformed the labels into a single-channel tensor, where each pixel represents a class index.

So I wrote this:

def segnet_label_translate(labels, num_classes):
    labels = labels.cuda().long()
    cuda0  = torch.device('cuda:0')
    b_size = labels.shape[0]
    tensor = torch.zeros([b_size, 224, 224], dtype=torch.long, device=cuda0)
    zeros  = torch.zeros([224, 224], dtype=torch.long, device=cuda0)

    for batch in range(b_size):
        for i in range(num_classes):
            ideal    = labels[batch][i]
            label    = i  + 1
            values   = torch.full([224, 224], label, dtype=torch.long, device=cuda0)
            filtered = torch.where((ideal != 0), values, zeros)
            tensor[batch] = filtered

    return tensor

I suspect there are easier or faster ways to write the above.
That seems to work when I remove the line that does label = i + 1; otherwise I get the assertion triggered. However, what is unclear is why I would use class zero.
Then I remembered reading somewhere that class 0 in segmentation is always the background or some other less useful class?

The above works if I don’t do label = i + 1, but for me class 0 is something valid that I look for, so I suppose I’ll have to account for that.

Also, should I be applying a softmax or any other transformation to the logits, or is there no need for it?

I’m not completely sure what your translate method is doing.
How are you defining the target at the moment?

Class 0 has the same importance as any other class. It is often used as the background, but you can assign whatever class you want to it.
Note that Python uses 0-indexing, which makes class 0 just the first class.
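
Before computing the loss, it can also help to check that the target values actually fit the number of output channels; the CUDA assertion t >= 0 && t < n_classes fires exactly when they don’t. A small sketch with stand-in tensors (replace them with your real output and target):

import torch

output = torch.randn(32, 7, 224, 224)          # model output with C = 7 channels
target = torch.randint(0, 7, (32, 224, 224))   # class indices, must lie in [0, C - 1]

assert target.min() >= 0 and target.max() < output.shape[1]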

No, as said above, nn.CrossEntropyLoss expects logits and uses F.log_softmax + nn.NLLLoss internally, so you should not add a softmax layer.

@ptrblck It takes the current [batch_size, nb_classes, height, width] target, which is spread across 7 channels (8 now, in order to add background), and squeezes the nb_classes dimension from 8 to 1, since my pixels/classes never overlap (I’m dealing with documents).

Hence it iterates over the [batch_size, nb_classes, height, width] tensor and transforms it into [batch_size, height, width]: where each channel previously held a binary value per pixel, it now holds the long integer representing the class, as you suggested.

I’m pretty sure someone else (you most likely) can do the above in a simpler and more elegant way.
PS: b_size is batch size, and labels[batch] is the [nb_classes, height, width] tensor.

Thanks for the explanation.
If each channel is one-hot encoded, you could try to use target = torch.argmax(target, dim=1) instead.
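
Something like this, for example (a small self-contained sketch with a fake one-hot target, assuming the channels really never overlap):

import torch
import torch.nn.functional as F

batch_size, nb_classes, height, width = 32, 7, 224, 224

# build a fake one-hot target with shape [batch_size, nb_classes, height, width]
class_idx = torch.randint(0, nb_classes, (batch_size, height, width))
one_hot = F.one_hot(class_idx, nb_classes).permute(0, 3, 1, 2)

# collapse the class channels into a single index map of shape [batch_size, height, width]
target = torch.argmax(one_hot, dim=1)

print(target.shape)                    # torch.Size([32, 224, 224])
print(torch.equal(target, class_idx))  # True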


Also, I forgot to ask. When deployed, how can I make sense of the logits in order to translate them into the pixel coordinates of a specific class?

You also get the predicted class for each pixel via preds = torch.argmax(output, dim=1).
This will return the indices of the maximal values in dim=1, which corresponds to the class dimension.
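
And to go from there to the pixel coordinates of a specific class, you could do something along these lines (a rough sketch with a stand-in output tensor; class 3 is just an example):

import torch

# stand-in for the network output: [batch_size, nb_classes, height, width]
output = torch.randn(32, 7, 224, 224)

preds = torch.argmax(output, dim=1)       # [batch_size, height, width]

class_id = 3
# (row, col) coordinates of every pixel predicted as class_id in the first image
coords = (preds[0] == class_id).nonzero()
print(coords.shape)                       # [num_pixels_of_class_3, 2]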

That’s interesting. This seems to return only one of the labels, not the others which I know are present (and the network has been trained). What could cause this?

I’m uploading an example of what I mean:

[image: human-readable mask]

In this picture (it is a human-readable mask) I’m only getting the green region (class 7) but not the other two regions (class 2 and class 3).

This is trained with the latest samples above, and with a pixel-wise accuracy of 99%.

How are you calculating the pixel-wise accuracy?
You would have to get the predicted classes somehow to calculate it, so you could also use it to visualize the prediction.
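
For reference, a pixel-wise accuracy based on the predicted class indices could look roughly like this (a sketch with stand-in tensors, assuming the target already has the shape [batch_size, height, width] with class indices):

import torch

batch_size, nb_classes, height, width = 32, 7, 224, 224

# stand-ins for the model output and the class-index target
output = torch.randn(batch_size, nb_classes, height, width)
target = torch.randint(0, nb_classes, (batch_size, height, width))

preds = torch.argmax(output, dim=1)     # [batch_size, height, width]
acc = (preds == target).float().mean()  # fraction of correctly classified pixels
print(acc.item())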

I am not sure that using cross entropy does what I want it to; I suspect it trains the network to maximise the class index, but in a single image I can have multiple classes present at various pixel locations. So far the network seems to learn to detect only the pixels of the class with the highest index?

Pixel-wise accuracy is done by the following snippet. I had assumed softmax was needed, so the code below won’t work without it; that is because each channel-per-class had a binary value.

            input   = images.cuda(non_blocking=True).float()
            labels  = labels.cuda(non_blocking=True).long()
            outputs = model(input)

            b_size = labels.shape[0]

            # iterate labels in dim 0, e.g., batch size
            for i in range(b_size):
                #
                #   squeeze and transform the [7,224,224] to
                #   a single dim [351232] of pixel values 
                #   all concat across one dim for all channels
                #
                label   = labels[i].view(1,-1).squeeze()
                actual  = f(outputs[i].view(1, -1).squeeze())
                l_size  = labels[i].shape[0]
                #
                #   pixel-wise comparison (note this requires softmaxed/min-maxed values)
                #
                actual  = (actual > 0.9).float() * 1
                same    = (actual == label).sum()
                total   = l_size * 224 * 224
                match   = same.item() / total
                #
                # 50176 is the 224 x 224 so we need to 
                # divide and normalise the correct values now
                #
                correct.append(match)
                total   += 1

I now suspect that what you had originally suggested (a target with single dim) would make more sense as it would allow me to argmax and calculate accuracy based on that.

So I’ve reverted to using targets in the shape [batch, height, width], where the values are the class indices (so no more channel per class).

Which, AFAIK, is what you have suggested. The loss function seems to work without hacking into any tensors or channels, simply with torch.argmax; however, my accuracy no longer works on a pixel-wise level, so I’m trying to implement it this way:

            outputs  = model(input)
            preds    = torch.argmax(outputs, dim=1)
            matching += (preds == labels).sum().item()
            correct  += matching / (224 * 224)   # normalise by the number of pixels per image

            b_size  = labels.shape[0]
            total   += b_size

This seems to give a lower Top-1 score than before, but not by much (98% vs 99%).

So I am testing it again with the deployment platform, and after running the following:

        t_in   = self.tf(image)
        t_in   = torch.unsqueeze(t_in, 0)
        logits = self.net(t_in)
        torch.set_printoptions(profile="full")
        torch.set_printoptions(precision=5)
        #logits = logits.squeeze()
        print(logits.shape)
        print(logits)
        preds  = torch.argmax(logits, dim=1)
        print(preds.shape)
        print(preds)
        return self.extract(image, logits)

I noticed that preds now contains the pixel class index correctly for all the classes present in the document.

So, THANK YOU, I owe you a few rounds of beer, your help is greatly appreciated!!!


Haha, good to hear it’s working now! :slight_smile: