Multi-Class Cross Entropy Loss function implementation in PyTorch

If you look at the formula for CrossEntropyLoss, you see that only the logit for the current target class is needed (x[class]).
So instead of masking the logits with a one-hot encoded target tensor, we just use the index of the appropriate class.
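
For illustration, a minimal sketch (the tensor names and sizes here are made up, not taken from your code) of how nn.CrossEntropyLoss consumes class indices for segmentation:

import torch
import torch.nn as nn

# CrossEntropyLoss for segmentation expects raw logits of shape [N, C, H, W]
# and a LongTensor target of shape [N, H, W] with a class index per pixel,
# not a one-hot encoded mask
criterion = nn.CrossEntropyLoss()

logits = torch.randn(1, 10, 256, 256)          # fake model output for 10 classes
target = torch.randint(0, 10, (1, 256, 256))   # fake class index per pixel
loss = criterion(logits, target)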

The line of code you cited from my post was just there to create a target image with the same properties as in the Kaggle dataset. You shouldn’t use it on your target!

Hi @ptrblck

Could you give me a code snippet showing where I need to make the required changes? Should these changes be made in the dataset class at the time of returning the mask, or in the train loop just before sending it to the loss function?

My train loop looks like this at the moment:

        for i, (train_batch, labels_batch) in enumerate(dataloader):
            # move to GPU if available
            if params.cuda:
                train_batch, labels_batch = train_batch.cuda(async=True), labels_batch.cuda(async=True)
            # convert to torch Variables
            train_batch, labels_batch = Variable(train_batch), Variable(labels_batch)

            # compute model output and loss
            output_batch = model(train_batch)

            # get targets from labels
            target = F.softmax(labels_batch, dim=1) > 0.5  # target now has a 1 in channel c if the pixel location belongs to that class
            labels = torch.argmax(target, dim=1)           # labels now has indices for the corresponding class

            # compute loss
            loss = loss_fn(output_batch, labels)
            logger.debug("train loss: {}".format(loss.data.item()))

            # clear previous gradients, compute gradients of all variables wrt loss
            optimizer.zero_grad()
            loss.backward()

            # performs updates using calculated gradients
            optimizer.step()

I’m not getting the right output when running a test image. All the masks look the same, so the NN did not really learn anything in terms of discriminating between the different classes. This is with a very small train set (20 images train, 2 dev, 3 test). Once I get the main train and eval loops fixed, I will do data augmentation and generate a larger set (50k train, 10k dev, 20k test).

This is the RGB test image:

These are the masks that the NN generated from the trained weights… they should all have been different, or at least blank for most of the classes, since there are, for example, no buildings in this test image:

(generated mask images for ch01 through ch05)

20 training images are indeed very few.
You don’t need to call target = F.softmax(labels_batch, dim=1) > 0.5 when your target images are already binary. I just added that line in my sample to create a fake binary target image.

The other line (labels = torch.argmax(target, dim=1)) could go into the Dataset's __getitem__ method.
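
For example, a minimal sketch (the class name and stored arrays are assumptions, not your actual code) of doing that mapping inside __getitem__:

import torch
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):
    # hypothetical dataset holding numpy arrays: images are [3, H, W], masks are [10, H, W] one-hot
    def __init__(self, images, masks):
        self.images, self.masks = images, masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.images[idx]).float()   # [3, H, W]
        mask = torch.from_numpy(self.masks[idx]).float()     # [10, H, W] one-hot
        mask = torch.argmax(mask, dim=0)                     # [H, W] class indices for CrossEntropyLoss
        return image, mask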

Are your training predictions looking good?

Hi @ptrblck

How can I reverse this so that I can debug it for display, i.e. get back the [ch=10, h=256, w=256] shape?

mask = torch.from_numpy(mask).float()
mask = torch.argmax(mask, dim=0)  # map target classes to indices, shape will be [1, 256, 256]

The argmax should be called on dim=1, but I suppose that’s just a typo.
If your mask has the shape [batch_size, height, width], you could try the following:

target = torch.zeros(1, 10, 24, 24)
target.scatter_(1, mask.unsqueeze(1), 1)
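
As a self-contained round trip (toy sizes assumed) to make the shapes explicit:

import torch

mask = torch.randint(0, 10, (1, 24, 24))    # [batch, H, W] class indices (LongTensor)
one_hot = torch.zeros(1, 10, 24, 24)        # [batch, C, H, W]
one_hot.scatter_(1, mask.unsqueeze(1), 1)   # put a 1 at each pixel's class channel
print(one_hot.shape)                        # torch.Size([1, 10, 24, 24])
assert torch.equal(one_hot.argmax(dim=1), mask)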

Hi @ptrblck

The generated mask from the dataset has the shape [10, 256, 256].

(If I apply dim=1, its shape becomes [10, 256].)

You have to use dim=1 if you have a batch dimension (which is not the case inside the dataset) and dim=0 if you don’t have one.
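
In other words (toy shapes assumed):

import torch

mask_single = torch.zeros(10, 256, 256)    # [C, H, W] inside the Dataset
mask_batch = torch.zeros(4, 10, 256, 256)  # [N, C, H, W] after the DataLoader

print(mask_single.argmax(dim=0).shape)     # torch.Size([256, 256])
print(mask_batch.argmax(dim=1).shape)      # torch.Size([4, 256, 256])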

Ok. Then I assume dim0 is the batch dimension?
If so, your mask should already contain the class indices.

EDIT: @justusschock is probably right! I think I’ve seen the data outside of the Dataset.

I think dim0 is the channel dimension, because the batch dimension is not present inside the dataset but is added by the DataLoader afterwards.
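
A quick way to see this (a toy dataset of [10, 256, 256] masks is assumed):

import torch
from torch.utils.data import DataLoader, TensorDataset

masks = torch.zeros(4, 10, 256, 256)                   # 4 samples, each [C, H, W]
loader = DataLoader(TensorDataset(masks), batch_size=2)

batch, = next(iter(loader))
print(batch.shape)                                     # torch.Size([2, 10, 256, 256]) - batch dim added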


The dataset output shape for a single mask is [10, 256, 256]. The dataset __getitem__ returns just 1 image and 1 mask.

This is the code fragment from my dataset class, just before it returns the image and mask. The image is [3, 256, 256], the mask is [10, 256, 256]. The U-Net model has a 3-channel input and a 10-channel output.

Should I be doing the class-to-index mapping here inside the dataset class, as follows:

        # TODO: Remove this. Temporarily resizing both image and mask.
        image = resize(image, 256, 256).transpose((2, 0, 1))
        mask = resize(mask, 256, 256).transpose((2, 0, 1))
        #mask = mask[0:3, :, :]  # extract only the first 3 channels of the mask

        self.logger.info("image type: {}, image shape: {}, image max pixel value ch0: {}".format(image.dtype, image.shape, np.amax(image[0, :, :])))
        self.logger.info("mask  type: {}, mask  shape: {}, mask  max pixel value ch0: {}".format(mask.dtype, mask.shape, np.amax(mask[0, :, :])))

        # TODO: Check if we have to convert the image and mask to torch tensors here?
        image = torch.from_numpy(image).float()
        mask = torch.from_numpy(mask).float()
        mask = torch.argmax(mask, dim=0)  # map target classes to indices

        self.logger.info("image type: {}, image shape: {}".format(image.dtype, image.shape))
        self.logger.info("mask  type: {}, mask  shape: {}".format(mask.dtype, mask.shape,))

        return image, mask
2018-06-04 13:35:23 INFO     | dataset:__getitem__:355: image type: float64, image shape: (3, 256, 256), image max pixel value ch0: 0.7590009134675589
2018-06-04 13:35:23 INFO     | dataset:__getitem__:356: mask  type: uint8, mask  shape: (10, 256, 256), mask  max pixel value ch0: 1
2018-06-04 13:35:23 INFO     | dataset:__getitem__:364: image type: torch.float32, image shape: torch.Size([3, 256, 256])
2018-06-04 13:35:23 INFO     | dataset:__getitem__:365: mask  type: torch.int64, mask  shape: torch.Size([256, 256])
2018-06-04 13:35:23 INFO     | train_unet:train:85: train output_batch.shape = torch.Size([1, 10, 256, 256]). labels_batch.shape = torch.Size([1, 256, 256])
2018-06-04 13:35:23 INFO     | evaluate_unet:display_mask_ch:145: mask shape: (256, 256, 10)

Looks good!
I don’t understand the last line in your logging.
The mask shape seems to be [W, H, C]? I assume you are trying to visualize it with matplotlib or another library.
Besides that, the code should run.

Yes, that was in another file, before I converted it back to [H, W, C] for matplotlib. I’ll modify it later on to push these predicted mask images to disk and try to display them in TensorBoard, so that I can see how the network learns each channel.
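
For reference, a minimal sketch (assuming a [C, H, W] float tensor) of that conversion for display:

import matplotlib.pyplot as plt
import torch

pred = torch.rand(10, 256, 256)                  # assumed [C, H, W] prediction
pred_np = pred.permute(1, 2, 0).cpu().numpy()    # -> [H, W, C] for matplotlib

plt.imshow(pred_np[..., 0], cmap=plt.cm.gray)    # show a single channel
plt.show()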

Hi @ptrblck

This is the output of my model:

train_unet:train:85: train output_batch.shape = torch.Size([1, 10, 256, 256]), labels_batch.shape = torch.Size([1, 256, 256])

For a labels_batch of shape [1, 256, 256], which was originally generated by the dataset class and sent to the model via the dataloader, I tried to do

labels_batch_cpu = labels_batch.data.cpu()
labels_batch_cpu.scatter_(0, labels_batch_cpu.unsqueeze(0), 0)

but I get this error:

File "/project/geospatial/application/cs230-sifd/source/main/train/train_unet.py", line 110, in train
    labels_batch_cpu.scatter_(0, labels_batch_cpu.unsqueeze(0), 0)
RuntimeError: invalid argument 3: Index tensor must have same dimensions as output tensor at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/TH/generic/THTensorMath.c:661

Am I correctly calling the scatter/unsqueeze operation on dim 0?

The tensor calling scatter_ should be a new one, not the label tensor.

tensor = torch.zeros(1, 10, h, w)
tensor.scatter_(1, labels_batch_cpu.unsqueeze(1), 1)

Hi @ptrblck

I just added a bit of debugging code in my dataset generator, to see if the masks are being sent to the model correctly.

        mask = self.mask_generator.mask(id=id_, height=h, width=w)

        # TODO: Remove this code
        import matplotlib.pyplot as plt
        # display individual mask channels
        for i in range(mask.shape[-1]):
            mask_ch = mask[..., i] * 255
            plt.figure(figsize=(10, 10))
            plt.title('mask ch{}'.format(i + 1))
            plt.imshow(mask_ch, cmap=plt.cm.gray)
            plt.show()

This is how the individual channel masks look, corresponding to buildings, roads, etc. There are 10 classes in total, which have been mapped to 10 channels in the generated mask. Here are the first 3 masks:

I just took a look at the labels_batch, and after unsqueezing it and trying to display the first channel, I see that the image contains a sum of the individual classes: it is an inverted, fully concatenated negative mask!

So the neural network doesn’t learn anything for the individual channels, because the target labels for the individual channels are all concatenated together.

Let me debug this and make sure the class-to-index tensor mapping and the subsequent recovery of the original data are working correctly.

This might be a visualization issue, since your unsqueezed image contains one color channel with pixel values indicating the classes. matplotlib or another lib might clip these values.
Try to visualize (labels_batch == class_index) for each class and see if you get the individual class target images.
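
For example, a small sketch (class count and shapes assumed) of that per-class visualization:

import matplotlib.pyplot as plt
import torch

labels_batch = torch.randint(0, 10, (1, 256, 256))   # stand-in for the real index map

for c in range(10):
    class_mask = (labels_batch[0] == c).float()       # binary image for class c
    plt.figure()
    plt.title('class {}'.format(c))
    plt.imshow(class_mask.numpy(), cmap=plt.cm.gray)
    plt.show()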

Hi @ptrblck

I wrote a small test case, for a batch size of 1:

    def test_mask_class_to_index_tensor_mapping(self):
        self.logger.info("generating mask")

        # mask parameters, which includes the number of channels to be included in the generated mask
        self.logger.debug("mask parameters: \n{}\n".format(pformat(self.mask_params)))

        # select a sample
        image_id = self.dataset_partition_params['train'][0]

        # generate mask
        mask = self.mask_generator.mask(id=image_id, height=3349, width=3391)
        self.logger.debug("generated mask shape: {}".format(mask.shape))

        # resize mask
        mask = resize(mask, 256, 256)

        # display mask
        display_mask(mask)

        # resize mask
        mask = resize(mask, 256, 256).transpose((2, 0, 1))

        # convert to torch variable
        mask = torch.from_numpy(mask).float()
        self.logger.info("mask tensor type: {}, mask  shape: {}".format(mask.dtype, mask.shape))

        # map target classes to tensor indices
        mask = torch.argmax(mask, dim=0).div(9.0)
        self.logger.info("mask shape after class to tensor index mapping: {}".format(mask.shape))

        """
        Now let's try to emulate the dataloader and retrieve the individual
        mask channels.
        """

        # emulate adding an extra batch dimension by the dataloader
        labels_batch = torch.unsqueeze(mask, 0)
        self.logger.info("labels_batch shape after unsqueeze: {}".format(labels_batch.shape))

        # convert labels_batch back to target classes, for visual debug purposes
        # the unsqueezed image contains one color channel and pixel values indicating the classes
        n, h, w = labels_batch.shape
        tensor = torch.zeros(n, 10, h, w)
        tensor.scatter_(1, labels_batch.unsqueeze(1), 1)
        self.logger.info("label tensor shape: {}".format(tensor.shape))

        # convert the label tensor back to a numpy array
        label_mask = tensor.numpy()[0, :, :, :].transpose([1, 2, 0]) * 9.0  # denormalize the mask values
        self.logger.info("label_mask shape: {}".format(label_mask.shape))
        display_mask(label_mask)

def display_mask(mask):
    logger.info("mask shape: {}".format(mask.shape))
    # display individual mask channels
    for i in range(mask.shape[-1]):
        mask_ch = mask[..., i]
        plt.figure(figsize=(10, 10))
        plt.title('mask ch{}'.format(i + 1))
        plt.imshow(mask_ch, cmap=plt.cm.gray)
        plt.show()

When I try to display it, this is what happens:

  • The original ch9 mask (the last mask) appears in two places in the labels batch: once as a negative image at channel location 0, and once as a correct ch9 mask image at channel location 1.
  • The rest of the labels batch channels, 2 to 9, are black and devoid of pixels.

This is the original mask at ch9 from the mask generation process:

Here are the labels_batch images after the extraction process:

  • labels_batch_ch0: negative image of the original ch9 mask.
  • labels_batch_ch1: correct image, but at the wrong channel; it should appear at ch9.
  • labels_batch_ch2 to ch9: empty, no pixels; they should have had some content.

There is something going wrong when unmapping the indices back to classes for display.

Any thoughts on what might cause this? I’ve put the full test case above.

BTW, the debugger shows that the class-to-index mapping is working correctly. I can see several entries with 5, 4, and 1 as the mask comes out of the dataset and the dataloader.

I have an additional question: should I divide the values in the mask by 9.0 to normalize them to the range [0, 1]? Something like this:

image = torch.from_numpy(image).float()
mask = torch.from_numpy(mask).float()
mask = torch.argmax(mask, dim=0).div(9.0)  # map target classes to indices, and normalize it

Additionally, should I center all the values to be in [-1, 1] for both the image and the mask as they are generated from the dataset?

Do you get any errors running your mapping code?
Apparently mask is a normalized FloatTensor, thus labels_batch is of the same type with an additional batch dimension.
The .scatter_ call should throw an error, since the index should be a LongTensor.
I am wondering why you can display anything at all, because tensor should be empty.
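
To illustrate the dtype requirement (toy shapes, not your data):

import torch

one_hot = torch.zeros(1, 10, 4, 4)
float_idx = torch.zeros(1, 1, 4, 4)         # FloatTensor index
# one_hot.scatter_(1, float_idx, 1)         # raises an error: the index must be a LongTensor
one_hot.scatter_(1, float_idx.long(), 1)    # works after casting the index to int64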

Anyway, what you are seeing is an error probably due to a missing background class.
Imagine you just have the class channels (channel 0 to 9). Each 1 represents an occurrence of the corresponding class at this pixel position.
Each pixel has one unique class, i.e. there should be only one 1 for the pixel position in all channels.
Now by using the transformation and scatter, you are getting the argmax at some point.
Since there are some pixels without any class correspondence (and no explicit background class),
these pixels will get an argmax of 0.
This puts all background pixels (pixels without any class) into the first class channel (channel 0), which will mix up your class0 and background.
Therefore, we have to introduce a background class.

I created a small example showing the transformation:

# Create a target
target = torch.zeros(24, 24, 10)
# Draw "segmentation masks"
for c in range(10):
    target[c*2:c*2+1, c*2:c*2+1, c] = 1
    
display_mask(target)

# Create background class
background = torch.ones(24, 24, 1) - (target[:, :, :10] == 1).float().sum(2, keepdim=True)
target = torch.cat((target, background), dim=2)
display_mask(target) # the last channel is the background class

mask = target.permute(2, 0, 1)
mask = torch.argmax(mask, dim=0)
labels_batch = mask.unsqueeze(0)
n, h, w = labels_batch.shape
tensor = torch.zeros(n, 11, h, w)
tensor.scatter_(1, labels_batch.unsqueeze(1), 1)
label_mask = tensor.numpy()[0].transpose(1, 2, 0)
display_mask(label_mask)

Let me know if I misunderstood something.

I’m not getting an error from the .scatter_ call in either case, whether the mask tensor is float() or long(). Perhaps an assertion could be added to the .scatter_ method, similar to the ones raised by the loss functions?

        # generate mask
        mask = self.mask_generator.mask(id=image_id, height=3349, width=3391)

        # resize mask
        mask = resize(mask, 256, 256)

        # display mask
        display_mask(mask)

        # convert to torch type (c x h x w)
        mask = mask.transpose((2, 0, 1))

        # convert to torch variable
        mask = torch.from_numpy(mask).float()  # doesn't error out for either float() or long()
        self.logger.info("mask tensor type: {}, mask  shape: {}".format(mask.dtype, mask.shape))

        # map target classes to tensor indices
        mask = torch.argmax(mask, dim=0) 

        """
        Now let's try to emulate the dataloader and retrieve the individual
        mask channels.
        """

        # emulate adding an extra batch dimension by the dataloader
        labels_batch = torch.unsqueeze(mask, 0)
        self.logger.info("labels_batch shape after unsqueeze: {}".format(labels_batch.shape))

        # convert labels_batch back to target classes, for visual debug purposes
        # the unsqueezed image contains one color channel and pixel values indicating the classes
        n, h, w = labels_batch.shape
        tensor = torch.zeros(n, 10, h, w)
        tensor.scatter_(1, labels_batch.unsqueeze(1), 1)
        self.logger.info("label tensor shape: {}".format(tensor.shape))

        # convert the label tensor back to a numpy array
        label_mask = tensor.numpy()[0, :, :, :].transpose([1, 2, 0])
        self.logger.info("label_mask shape: {}".format(label_mask.shape))
        display_mask(label_mask)