Multi-Class Cross Entropy Loss function implementation in PyTorch

hi @ptrblck

I’m going to do the background mask generation in numpy, since I’m already using it to process all my images and generate masks, and then convert everything to torch tensors in one step.

I’d like to do the equivalent of this torch code in numpy:

background = torch.ones(24, 24, 1) - (target[:, :, :10] == 1).float().sum(2, keepdim=True)

Does the following code snippet, using numpy.logical_or with reduce, look okay to you, or is there a more efficient way? I guess the most efficient route would be to do all the processing on the GPU using PyTorch, but I’ll rework my code after I get the main NN model working.

background = np.expand_dims(np.logical_or.reduce((mask[:, :, :10] == 1), axis=2), axis=2)
background = np.ones((height, width, 1)) - background
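
One thing I noticed while testing: the two versions only agree when at most one class is set per pixel, since the channel sum can exceed 1 while logical_or saturates at 1. A quick sanity-check sketch with a random toy mask (shapes and channel count are just placeholders):

import numpy as np
import torch

# toy 24x24 mask with 11 binary channels (placeholder data)
mask = (np.random.rand(24, 24, 11) > 0.7).astype(np.float32)
target = torch.from_numpy(mask)

# torch version: 1 - channel sum (can go negative where classes overlap)
bg_torch = torch.ones(24, 24, 1) - (target[:, :, :10] == 1).float().sum(2, keepdim=True)

# numpy version: 1 - logical_or (background stays in {0, 1})
bg_np = np.expand_dims(np.logical_or.reduce(mask[:, :, :10] == 1, axis=2), axis=2)
bg_np = np.ones((24, 24, 1)) - bg_np

# False whenever some pixel has more than one class set
print(np.array_equal(bg_torch.numpy(), bg_np))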

Hi @ptrblck

I have a few questions about types, sizes, and normalization.

Q01: Are the following types correct for input to the model?

image is of type float
mask is of type long

Q02: Is it possible to optimize the sizes of the image and mask to use fp16 and uint8, to reduce GPU memory requirements? I will be increasing the number of input channels to 20 once I’ve validated that the model works with 3 channels, so it will demand a lot of GPU memory.

Q03: If you have multiple GPUs working with DataLoaderParallel, will the memory be distributed for a single batch? E.g. if I have a batch size of 6 and 2 GPUs, will each GPU train with 6 images for every epoch? I think this is what I see happening, but I’m not quite sure.

Q04: Normalization. At the moment I’m normalizing the images to [0, 1]. Should I also center them to [-1, 1] before sending them for training?

What about the mask? Do I leave it alone, with no normalization, on the assumption that it doesn’t make a difference because the pixel values in each channel are either 0.0 or 1.0 float?

Your numpy result image looks good so I would go for it now!

Q01: Yes, if you are using a criterion for classification, e.g. nn.CrossEntropyLoss or nn.NLLLoss, that’s correct.
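
A minimal sketch of the expected types and shapes (the sizes are just placeholders):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# model input: FloatTensor, [batch_size, channels, height, width]
image = torch.randn(2, 3, 24, 24)

# model output (logits): FloatTensor, [batch_size, nb_classes, height, width]
output = torch.randn(2, 11, 24, 24)

# target: LongTensor, [batch_size, height, width], class indices in [0, nb_classes-1]
mask = torch.randint(0, 11, (2, 24, 24)).long()

loss = criterion(output, mask)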

Q02: It should be possible, but I’m not sure how easy the workflow is, and especially how easy or hard the training might be. Since my GPU doesn’t support native float16, I haven’t used it yet. What kind of GPU do you have? Alternatively, you could trade some compute for memory using torch.utils.checkpoint, if your memory is limited.
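
Roughly like this (just a sketch with a hypothetical block; the point is that the checkpointed activations are recomputed in the backward pass instead of being stored):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.stem = nn.Conv2d(3, 64, 3, padding=1)
        # hypothetical memory-hungry middle block
        self.block = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
        )
        self.head = nn.Conv2d(64, 11, 1)

    def forward(self, x):
        x = self.stem(x)
        # trade compute for memory: activations inside self.block
        # are recomputed during the backward pass
        x = checkpoint(self.block, x)
        return self.head(x)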

Q03: I assume you are referring to DataParallel. No, dimension 0 will be distributed among the GPUs, so in the usual workflow the batch will be split. If you have 2 GPUs and a batch size of 6, each GPU will receive 3 images.
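
E.g. (a minimal sketch with a dummy model):

import torch
import torch.nn as nn

# dummy model just for illustration
model = nn.Conv2d(3, 11, 3, padding=1)
model = nn.DataParallel(model, device_ids=[0, 1]).cuda()

# dim 0 is scattered across the GPUs: with a batch size of 6
# and 2 GPUs, each replica gets 3 images
images = torch.randn(6, 3, 256, 256).cuda()
output = model(images)  # gathered back onto the default device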

Q04: I’ve seen both methods work well. Probably the most common approach is to standardize the data by subtracting the mean and dividing by the standard deviation. If you are training from scratch, you could just try different approaches and see which one works best. If you are using a pre-trained model, you should definitely use the same preprocessing that was used in pre-training.

You shouldn’t normalize the mask, since the mask is used as the indices for your loss function.
Also, you cannot really normalize a LongTensor, so this would also create a type error.
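
For the image, this could look like the following sketch (the mean and std values are placeholders; you would compute per-channel statistics from your own training set):

import torch

# placeholder per-channel statistics (compute these from your training data)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

image = torch.rand(3, 256, 256)          # float image in [0, 1]
image = (image - mean) / std             # standardize the image only

mask = torch.randint(0, 11, (256, 256))  # LongTensor of class indices
# no normalization for the mask - it holds indices, not intensities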

Regarding Question 2: Even if the GPU does not support float16 natively, you can still train your model in float16 and it will save some memory. However, most (almost all) GPUs before the Pascal P100 do not support float16 natively, which causes a small extra computational cost. If you do not have a time constraint, this could still be worth doing, as float16 precision is sufficient for most use cases. It would be interesting to see if it performs faster than torch.utils.checkpoint.

Note: If you use any normalization layers (e.g. BatchNorm), you will have to cast them back to float32, as they don’t work with other precisions yet.
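
Roughly (a sketch with a dummy model; this works because cudnn’s batch norm accepts float16 inputs with float32 parameters):

import torch
import torch.nn as nn

# dummy model for illustration
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 11, 1),
).cuda().half()  # cast parameters and buffers to float16

# cast the normalization layers back to float32
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.float()

# the input has to be float16 as well
images = torch.randn(2, 3, 256, 256).cuda().half()
output = model(images)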

I have 3x Titan V 12GB HBM2 GPUs, and a Quadro P6000 with 24GB GDDR5x.

These GPUs should all support float16 natively, as they are built with the Pascal or Volta architecture.

From what I’ve been able to gather, the Volta Titan V’s and the Pascal P100 and GP100 (server) GPUs support FP16.

The Quadro P6000 is a Pascal-architecture card, but I’m not certain about its FP16 support. Some threads talk about FP16 being enabled only for debug mode, running at 1/64 performance. Another thread mentioned that it has 1:1 FP16 support, recently enabled via a software driver update along with the Titan XP GPU. So I’m not quite sure about that. The P6000 card is good because it has a lot of memory, but the Titan Vs have nearly twice the throughput with half the memory (12GB), though it is faster HBM2 with higher memory bandwidth.

The only drawback I see with the Titan V is if I want to train using really high-resolution images. A 2048x2048 RGB image just about fits into 12GB for a 3 layer U-Net model utilizing about 91% of the GPU resources.


Hi @ptrblck

When I convert the generated mask to a Torch tensor, and convert it back, I see some corruption in the mask, with regular grid-like patterns appearing in the extracted mask as seen below. The problem gets exacerbated when using higher resolution input images at 2048x2048.

  • background mask ch0 corruption after tensor extraction, using a resolution of 256x256
  • background mask ch0 corruption after tensor extraction, using a resolution of 2048x2048
  • ch3 corruption after tensor extraction, using a resolution of 256x256
  • ch4 corruption after tensor extraction, using a resolution of 256x256
  • ch6 corruption after tensor extraction, using a resolution of 256x256

This is a code snippet for the full test case. The mask before conversion to Torch tensors is fine, but it gets corrupted after extraction from the Torch tensor.


def test_mask_class_to_index_tensor_mapping(self):
    self.logger.info("generating mask")

    # mask parameters, which includes the number of channels to be included in the generated mask
    self.logger.debug("mask parameters: \n{}\n".format(pformat(self.mask_params)))

    # select a sample
    image_id = self.dataset_partition_params['train'][0]

    # generate mask
    mask = self.mask_generator.mask(id=image_id, height=3349, width=3391)
    self.logger.debug("generated mask type: {}, shape: {}".format(mask.dtype, mask.shape))

    # resize mask
    mask = resize(mask, 256, 256)

    # display mask
    display_mask(mask)

    # convert to torch type (c x h x w)
    mask = mask.transpose((2, 0, 1))

    # convert to torch tensor to type long
    mask = torch.from_numpy(mask).long()
    self.logger.info("mask tensor type: {}, mask shape: {}".format(mask.dtype, mask.shape))

    # map target classes to tensor indices
    mask = torch.argmax(mask, dim=0)
    self.logger.info("mask shape after class to tensor index mapping: {}".format(mask.shape))

    """
    Now let's try to emulate the dataloader and retrieve the individual
    mask channels.
    """

    # emulate adding an extra batch dimension by the dataloader
    labels_batch = torch.unsqueeze(mask, 0)
    self.logger.info("labels_batch shape after unsqueeze: {}".format(labels_batch.shape))

    # convert labels_batch back to target classes, for visual debug purposes
    # the unsqueezed image contains one color channel and pixel values indicating the classes
    n, h, w = labels_batch.shape
    tensor = torch.zeros(n, self.params.out_channels, h, w)
    tensor.scatter_(1, labels_batch.unsqueeze(1), 1)
    self.logger.info("label tensor shape: {}".format(tensor.shape))

    # convert the label tensor back to a numpy array
    label_mask = tensor.numpy()[0, :, :, :].transpose([1, 2, 0]) # * 9.0  # denormalize the mask values
    self.logger.info("label_mask shape: {}".format(label_mask.shape))
    display_mask(label_mask)

Q01: Is this a known issue?

Q02: What could be the reason for the mask data corruption after extracting it from the Torch tensor?

Q03: Is there some other operation other than mask = torch.argmax(mask, dim=0) that will perform the required class to index mapping without causing these artifacts?

That looks really strange.
Could you upload a single target image, so that I could debug this issue?

I’ve uploaded 3 TIFF image files:

  • generated mask, high resolution ch=11, height=3349, width=3391
  • resized mask, 256x256
  • output mask after extracting from tensor, showing corruption

Here is the link: https://www.dropbox.com/s/ygj0ofc22qjxu9b/mask-images.zip?dl=0

The high-res mask image with 11 channels is a bit large (550MB), because it was a raw save without compression. The 256x256 mask is 2.9MB, and the corrupted mask is 2.9MB.

Let me know if you can read them okay. I used tiff.imsave to save them; you should be able to load them using tiff.imread.

Thanks for the links!
I played around a bit with the images and there seem to be a few issues.

  • The big image (3349x3391) has only ones and zeros in it, which is fine.
    However, the smaller image (256x256) seems to have been created with linear interpolation or some other kind, which results in arbitrary floating point values. I had to threshold the image to remove these outliers.

  • There seems to be an error or misunderstanding in the dataset. In both images some pixels have more than one class set, e.g. in the small image at pixel position [1, 55] both class4 and class6 are set.
    Is this an error, or is the dataset a multi-label set?

Unfortunately, I couldn’t reproduce the corruption in the images you provided.
Could it be a visualization error?
With this code, the images after the transformation come out identical to the inputs:

import tifffile
import torch

# Read file
img = tifffile.imread(image_path)

# Remove wrong values in small image (due to interpolation)
img[img <= 0.5] = 0
img[img > 0.5] = 1
mask = img.copy()

mask = mask.transpose((2, 0, 1))
mask = torch.from_numpy(mask).long()
mask = torch.argmax(mask, dim=0)
labels_batch = torch.unsqueeze(mask, 0)
n, h, w = labels_batch.shape
tensor = torch.zeros(n, 11, h, w)
tensor.scatter_(1, labels_batch.unsqueeze(1), 1)

mask_rev = tensor[0].permute(1, 2, 0).numpy()
display_mask(mask_rev)

The original image had 1s and 0s, but I missed the fact that when I temporarily rescaled the mask to 256x256, it got converted to float values.

This is the opencv code that was used to resize the mask.

import cv2

def resize(img, height, width,
           interpolation=cv2.INTER_CUBIC):
    """
    Resize an image using OpenCV.
    Note that cv2.resize dsize is (width, height).

    :param img: Input image
    :param height: target height
    :param width: target width
    :param interpolation: cv2 interpolation type
    :return: resized image
    """

    return cv2.resize(img, (width, height), interpolation=interpolation)

Do you think this could be caused by an overlap of classes, e.g. a car on a road?

The dataset is a multi-class dataset. I won’t be trying to label every instance of a class for this.

You could change the interpolation to cv2.INTER_NEAREST for the target.
The source image can be interpolated with cv2.INTER_CUBIC, although a bilinear interpolation might even be sufficient.
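
E.g. (a sketch with placeholder data; note that cv2.resize takes dsize as (width, height)):

import cv2
import numpy as np

# placeholder data: a float image and an 11-channel binary mask
image = np.random.rand(512, 512, 3).astype(np.float32)
mask = (np.random.rand(512, 512, 11) > 0.5).astype(np.uint8)

# nearest neighbor keeps the mask binary (no interpolated in-between values)
mask_small = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

# cubic (or bilinear) interpolation for the image itself
image_small = cv2.resize(image, (256, 256), interpolation=cv2.INTER_CUBIC)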

Well, it depends on your goal. In the current scenario we only keep one class for each pixel position.
So if for example there is a car on the road, the “lower” class value will be kept by argmax.
For the pixel position [1, 55] only class4 will be stored, while class6 is lost.

If you care about both labels, we don’t have to save the indices and can keep the channel dimensions.
In your model you would have to predict these channels and apply some criterion like BCELoss on it.
What do you think?

What I will do is remove the default interpolation setting of cv2.INTER_CUBIC from the resize function, so that I can keep track of which interpolation types are applied to float images and which to binary mask data.

Some of the images in the DSTL dataset are low resolution (e.g. 136x134) and need to be resized to 3k x 3k, so cv2.INTER_CUBIC appeared to give good visual results, although it might get worse when you zoom in. Later on, I can replace the image resize operation with CNN-based super-resolution.

Yes, I do care about both labels.

What I’ve seen from other solutions to the same problem is that they used an ensemble of U-Nets, each trained on a binary segmentation task for a single class (e.g. roads), and then combined the output predictions of all the networks into a final colored mask.

However, I want to use a single U-Net network to train all 10 classes, even if the performance might be low.

Q01: If I switch to BCELoss, will it work on a U-Net model with a 3-channel input and 11-channel output, trained using an 11-channel target mask?

I ask this because most other U-Net implementations use BCELoss for binary segmentation of a single class; I’ve never seen it used for a multi-class segmentation problem.

What do you think?

Yes, your UNet will work, and you won’t even need the background channel, since each class is now a binary prediction. So if no class is detected for a particular pixel, it simply won’t get a class, which means we could use the target images (10 channels) directly.
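
A minimal sketch of the multi-label setup (shapes are placeholders):

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# model output after a sigmoid: FloatTensor, [batch_size, nb_classes, height, width]
output = torch.sigmoid(torch.randn(2, 10, 256, 256))

# multi-label target: FloatTensor of 0s and 1s with the same shape as the output
# (several channels may be set for the same pixel)
target = (torch.rand(2, 10, 256, 256) > 0.5).float()

loss = criterion(output, target)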

I’ve created some UNet code a while ago and adapted it to your use case.
You can find the gist here.

Note that I just created one random input image and one random target image.
You have to add your data loading and processing into the code, but I hope it will be a good starter!

Hi @ptrblck

Q01: The documentation for ToTensor states that it will convert a numpy image in the range [0, 255] to a [c, h, w] tensor in the range [0.0, 1.0] (float).

So, if I use BCELoss and apply the transform in the Dataset, should it be applied to both the image and the mask? I.e. should the mask stay as uint8 in [0, 1], or be transformed to float in [0.0, 1.0]?

Q02: Conversely, just to clarify and be certain: for another problem that uses CrossEntropyLoss, should I leave the binary mask untouched as uint8 in the range [0, 1] and not apply ToTensor to the target mask?

Hi @ptrblck

The earlier mask corruption was entirely attributable to something going wrong in the OpenCV resize. I tried the same operation without resizing: after doing the class-to-index mapping, I converted the mask to a torch tensor and back, and the original mask came through without any issues or corruption. This would explain why I had issues training the NN in some of the early test runs. I’m now only going to crop images to the required size and feed them into the NN.

Have you tried resizing your images with other frameworks, such as PIL or scikit-image?

To Q01:
If you use BCELoss, your target should have the same type as the output, so you would have to transform it into a FloatTensor. You should leave the ones and zeros as they are.

To Q02:
No, if you use another loss function like CrossEntropyLoss you would have to get the indices of the classes as we have done before.
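
Side by side, as a sketch:

import torch

# binary mask with one channel per class (placeholder data)
mask = torch.randint(0, 2, (2, 10, 256, 256))

# BCELoss: float target with the same shape as the model output
bce_target = mask.float()

# CrossEntropyLoss: long target of class indices, channel dimension removed
ce_target = torch.argmax(mask, dim=1)  # shape [2, 256, 256], dtype long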

No, but I’ll try it out.