Error when saving UNET image predictions to folder: TypeError: Cannot handle this data type: (1, 1, 5), |u1

I have a UNET segmentation model with 5 classes, and I am having trouble saving the image predictions.

Here is the code:

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

My tensor preds has shape [1, 5, 100, 100] because I have 5 classes and the image size is 100x100.

I get the following error message:

Traceback (most recent call last):
  File "/REDACTED/REDACTED/REDACTED/pkg/anaconda/anaconda3-2021a/lib/python3.8/site-packages/PIL/Image.py", line 2749, in fromarray
    mode, rawmode = _fromarray_typemap[typekey]
KeyError: ((1, 1, 5), '|u1')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "train.py", line 125, in <module>
    main()
  File "train.py", line 119, in main
    save_predictions_as_imgs(
  File "/home/REDACTED/REDACTED/unet-segmentation/utils.py", line 116, in save_predictions_as_imgs
    torchvision.utils.save_image(
  File "/REDACTED/REDACTED/REDACTED/pkg/anaconda/anaconda3-2021a/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/REDACTED/REDACTED/REDACTED/pkg/anaconda/anaconda3-2021a/lib/python3.8/site-packages/torchvision/utils.py", line 134, in save_image
    im = Image.fromarray(ndarr)
  File "/REDACTED/REDACTED/REDACTED/pkg/anaconda/anaconda3-2021a/lib/python3.8/site-packages/PIL/Image.py", line 2751, in fromarray
    raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e
TypeError: Cannot handle this data type: (1, 1, 5), |u1

Any reason why this might be occurring? I tried the following suggestions but they do not seem to work:

Thank you!

Based on the error message, it seems that PIL doesn’t recognize an “image” array with 5 channels, as seen here:

x = torch.empty(1, 1, 5).uniform_().byte()
PIL.Image.fromarray(x.numpy())
> TypeError: Cannot handle this data type: (1, 1, 5), |u1

so you would have to make sure the model output is using an expected shape (e.g. 3 or 1 channel(s)).
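
As a rough way to check this on your own setup, here is a minimal sketch that asks PIL which channel counts it accepts for a uint8 array. The helper name try_fromarray is hypothetical (it is not part of PIL), and the tiny 4x4 size is just for illustration:

```python
import numpy as np
from PIL import Image


def try_fromarray(channels, height=4, width=4):
    """Return the PIL mode chosen for an HxWxC uint8 array, or None if PIL rejects it.

    Hypothetical helper for illustration only.
    """
    arr = np.zeros((height, width, channels), dtype=np.uint8)
    if channels == 1:
        # PIL expects a plain 2D array for a single-channel (grayscale) image.
        arr = arr[:, :, 0]
    try:
        return Image.fromarray(arr).mode
    except TypeError:
        # Same failure as in the traceback: no known image mode for this layout.
        return None


for c in (1, 3, 5):
    print(c, try_fromarray(c))
```

With 1 and 3 channels PIL should pick a grayscale and an RGB mode respectively, while 5 channels hits the same TypeError as in the traceback.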

Hi, thank you for your response!

I am confused because I thought that the output of a UNET should be [batch_size, # of classes, height, width].

When I print preds.shape right after preds = (preds > 0.5).float(), I get Prediction shape: torch.Size([1, 5, 100, 100]). Am I doing something wrong?

Update: I realized I was calling my UNET model incorrectly.

Previously I had model = UNET(num_classes=5, in_channels=3, out_channels=5).to(DEVICE) but it should be model = UNET(num_classes=5, in_channels=3, out_channels=3).to(DEVICE).

Actually, sorry, the issue has not been resolved.

For a UNET multi-class segmentation model, the output should be [batch_size, # of classes, height, width].

So with 5 classes, a batch size of 1 and an image size of 100x100, the output should be [1, 5, 100, 100] which is correct.

The issue is that I am trying to save this prediction directly as an image, and PIL does not accept a 5-channel array (it expects an image-like layout such as 1 or 3 channels), which makes sense. How can I save my model output as an image if the shape is technically incorrect for an image? Thank you.

You won’t be able to save it directly as an image if it’s not in a valid image format.
In case you want to store the output directly, you could use torch.save. On the other hand, you might want to store the predictions for each class as a color-coded image.
In that case, you could get the predicted class indices via preds = torch.argmax(output, dim=1) (assuming output contains the logits/probabilities) and use a color mapping to create an RGB image (e.g. class0 -> RED, class1 -> BLUE, etc.).
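
A minimal sketch of the color-coding idea, assuming output has shape [N, 5, H, W] and holds logits or probabilities. The palette and the helper names colorize and save_colorized are made up for illustration; only the first two colors follow the example mapping above, the rest are arbitrary:

```python
import torch
from PIL import Image

# Hypothetical palette: one RGB color per class index (5 classes assumed).
PALETTE = torch.tensor(
    [
        [255, 0, 0],    # class 0 -> red
        [0, 0, 255],    # class 1 -> blue
        [0, 255, 0],    # class 2 -> green (arbitrary choice)
        [255, 255, 0],  # class 3 -> yellow (arbitrary choice)
        [0, 0, 0],      # class 4 -> black (arbitrary choice)
    ],
    dtype=torch.uint8,
)


def colorize(output):
    """Map model output [N, C, H, W] to uint8 RGB images of shape [N, H, W, 3]."""
    class_idx = torch.argmax(output, dim=1)  # [N, H, W], values in 0..C-1
    return PALETTE[class_idx.cpu()]          # look up one color per pixel


def save_colorized(output, path_template):
    """Save each image in the batch; path_template must contain an {i} placeholder."""
    for i, rgb in enumerate(colorize(output)):
        Image.fromarray(rgb.numpy()).save(path_template.format(i=i))


output = torch.randn(1, 5, 100, 100)  # stand-in for model(x)
print(colorize(output).shape)         # torch.Size([1, 100, 100, 3])
```

In the original function this would replace the sigmoid/threshold step and the save_image call for preds, e.g. save_colorized(model(x), folder + "/pred_" + str(idx) + "_{i}.png"). The argmax picks exactly one class per pixel, so the result is a valid 3-channel image.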

I am not sure.
Could you please confirm the number of channels in the UNET output? I doubt that it would be 5.
Isn't there an argmax that gives a semantic map?
How are you defining the target here? (Masks generally have one channel, so how are you reading the target?)

Thank you for your response as always!

I believe this post has an implementation of your predicted class indices suggestion: How to save multi-class segmentation prediction as image? - #2 by eqy.

I believe the idea is correct but I am having trouble getting it to run. If you could take a look at that then that would be really amazing! Thank you very much.