How to save a multi-class segmentation prediction as an image?

Hi,

My multi-class UNET model output is of the following shape: [1, 6, 100, 100] which is expected because the batch size is 1, I have 6 classes, and the image size is 100x100.

How can I save a prediction as an image containing all 6 classes using torchvision.utils.save_image? The tensor passed to save_image is expected to have either 1 or 3 channels, but my model technically outputs 6 channels, one per class.

Thank you.

There might be a utility function somewhere that does this, but you can write your own, something like:

    class_to_color = [torch.tensor([1.0, 0.0, 0.0]), ...]
    output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
    for class_idx, color in enumerate(class_to_color):
        mask = out[:,class_idx,:,:] == torch.max(out, dim=1)
        mask = mask.unsqueeze(1) # should have shape 1, 1, 100, 100
        curr_color = color.reshape(1, 3, 1, 1)
        segment = mask*color # should have shape 1, 3, 100, 100
        output += segment

Thank you for your response. I get the following error when executing that code:

mask = mask.unsqueeze(1) # should have shape 1, 1, 100, 100
AttributeError: 'bool' object has no attribute 'unsqueeze'

since mask is a bool from mask = out[:,class_idx,:,:] == torch.max(out, dim=1).

What are we trying to accomplish with this line? How should we rewrite it so that mask is a tensor and not a bool? Thank you!

Ah, that’s because torch.max with a dim argument actually gives back both the values and the indices (as a namedtuple); changing it to mask = out[:,class_idx,:,:] == torch.max(out, dim=1)[0] should fix that issue. Updated code:

    class_to_color = [torch.tensor([1.0, 0.0, 0.0]), ...]
    output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
    for class_idx, color in enumerate(class_to_color):
        mask = out[:,class_idx,:,:] == torch.max(out, dim=1)[0]
        mask = mask.unsqueeze(1) # should have shape 1, 1, 100, 100
        curr_color = color.reshape(1, 3, 1, 1)
        segment = mask*color # should have shape 1, 3, 100, 100
        output += segment
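
For reference, torch.max with a dim argument returns a (values, indices) namedtuple rather than a single tensor, which is why the [0] is needed:

>>> out = torch.randn(1, 6, 100, 100)
>>> values, indices = torch.max(out, dim=1)
>>> values.shape
torch.Size([1, 100, 100])
>>> indices.shape
torch.Size([1, 100, 100])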

Thank you. That solved that issue, but now I am running into a new one.

segment = mask*color # should have shape 1, 3, 100, 100
RuntimeError: The size of tensor a (100) must match the size of tensor b (3) at non-singleton dimension 3

Also, the shape of mask after mask = mask.unsqueeze(1) is [1, 3, 100, 100] instead of [1, 1, 100, 100], but I am not sure why.

I cannot reproduce that issue.
Can you check that the shape of your output is as expected?

>>> out = torch.randn(1, 6, 100, 100)
>>> mask = out[:,0,:,:] == torch.max(out, dim=1)[0]
>>> mask.shape
torch.Size([1, 100, 100])
>>> mask.unsqueeze(1).shape
torch.Size([1, 1, 100, 100])

Yes, the shape is correct.

Here is the code:

            pred = torch.sigmoid(model(x))
            out = (pred > 0.5).float()
            print(f"out shape: {out.shape}\n")
            class_to_color = [torch.tensor([0.0, 0.0, 0.0]), torch.tensor([10, 133, 1]), torch.tensor([14, 1, 133]), torch.tensor([33, 255, 1]), torch.tensor([243, 5, 247]), torch.tensor([255, 0, 0])]
            output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
            for class_idx, color in enumerate(class_to_color):
                mask = out[:,class_idx,:,:] == torch.max(out, dim=1)[0]
                print(f"{mask}\n")
                mask = mask.unsqueeze(1) # should have shape 1, 1, 100, 100
                print(f"mask shape: {mask.shape}\n")
                curr_color = color.reshape(1, 3, 1, 1)
                print(f"color shape: {color.shape}\n")
                segment = mask*color # should have shape 1, 3, 100, 100
                output += segment
            torchvision.utils.save_image(output, f"{folder}/pred_{idx}.png")

Here is the corresponding output:

out shape: torch.Size([1, 6, 100, 100])

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]], device='cuda:0')

mask shape: torch.Size([1, 1, 100, 100])

color shape: torch.Size([3])

And the error:

segment = mask*color # should have shape 1, 3, 100, 100
RuntimeError: The size of tensor a (100) must match the size of tensor b (3) at non-singleton dimension 3

Ah, curr_color should be used instead:

    class_to_color = [torch.tensor([1.0, 0.0, 0.0]), ...]
    output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
    for class_idx, color in enumerate(class_to_color):
        mask = out[:,class_idx,:,:] == torch.max(out, dim=1)[0]
        mask = mask.unsqueeze(1) # should have shape 1, 1, 100, 100
        curr_color = color.reshape(1, 3, 1, 1)
        segment = mask*curr_color # should have shape 1, 3, 100, 100
        output += segment

However, note that you should make sure the color formatting is consistent (e.g., either floating point values between 0.0 and 1.0 or integers between 0 and 255).
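
For example, one way (a sketch) to keep the whole palette in the 0.0–1.0 range is to define it with the integer colors and normalize once up front:

    class_to_color = [
        torch.tensor([0, 0, 0]) / 255.0,    # true division promotes to float,
        torch.tensor([10, 133, 1]) / 255.0, # so every color lands in [0.0, 1.0]
        torch.tensor([14, 1, 133]) / 255.0,
        ...
    ]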

Thank you! That solved that issue, but now I am getting the following one :face_with_monocle:

segment = mask*curr_color # should have shape 1, 3, 100, 100
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Did one of mask or curr_color get implicitly allocated on the CPU? I thought everything gets allocated on the GPU by default.

Nope, things are allocated on the CPU by default. You can simply add device='cuda' to the torch.tensor(...) calls to fix this.
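
For example (a sketch of just the palette; the same device= argument goes on every torch.tensor(...) call):

    class_to_color = [torch.tensor([1.0, 0.0, 0.0], device='cuda'), ...]

Alternatively, device=out.device keeps everything on whatever device the model output already lives on.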


Works like a charm. I also had to move output to the GPU.

Thank you very much I really appreciate your prompt and helpful responses! :slight_smile:


Hi @eqy, related to this, do you know how I can use this approach of iterating class by class so that I can compute the accuracy for each class?

This is what I currently have. check_accuracy gets executed for each epoch during training.

def check_accuracy(loader, model, device="cuda"):
    model.eval()

    with torch.no_grad():
        for x, y, _ in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = model(x)
            num_classes = preds.shape[1]
            for class_idx in range(num_classes):
                curr_class = (preds[:,class_idx,:,:] == torch.max(preds, dim=1)[0]).float()
                curr_class = curr_class.unsqueeze(1)
                num_correct = (curr_class == y).float()
                class_accuracy = num_correct.sum()/num_correct.numel()
                class_accuracy = class_accuracy*100
                print(f"{class_accuracy}% accuracy for class {class_idx}\n")
    model.train()

Is this the correct way to do it?

Thank you very much!

I don’t see anything obviously wrong with this at first glance, but I also don’t recall the exact data layout and label format used here. The best practice would be to write some predictable test cases if you aren’t sure about the correctness (e.g., try generating synthetic model outputs and labels, like an entire image with just one class, or half and half). You could also check that random outputs get roughly the accuracy of random guessing, etc.
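
For instance, a minimal synthetic check along those lines (a sketch; it assumes the labels are per-pixel class indices and binarizes them per class with y == class_idx, which is an assumption about your label format):

    # Synthetic "model output": class 2 wins at every pixel
    preds = torch.zeros(1, 6, 100, 100)
    preds[:, 2, :, :] = 1.0

    # Matching labels: every pixel belongs to class 2
    y = torch.full((1, 1, 100, 100), 2)

    # The class-2 mask should be all ones here, so a correct per-class
    # accuracy computation should report 100% for class 2
    curr_class = (preds[:, 2, :, :] == torch.max(preds, dim=1)[0]).float().unsqueeze(1)
    print((curr_class == (y == 2).float()).float().mean().item() * 100)  # 100.0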


Hi,

I just have a follow-up question. Since it’s been almost 2 months, here is a refresher.

We wanted to visualize the output of a UNET multi-class segmentation model. The model outputs a tensor of shape [batch size, number of classes, h, w]; taking the max across the class dimension gives, for each pixel, a class index in [0, number of classes). Obviously, if the number of classes isn’t 1 or 3, this tensor is not a valid image (and its raw values aren’t valid pixel values either). So we map each class index to its corresponding color and accumulate each class into an output of shape [1, 3, h, w], which we can then save as an image in order to visualize it.

for idx, (x, y) in enumerate(loader):
    x = x.to(device=device)
    with torch.no_grad():
        out = model(x)
        class_to_color = [torch.tensor([1.0, 0.0, 0.0]), ...]
        output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
        for class_idx, color in enumerate(class_to_color):
            mask = out[:,class_idx,:,:] == torch.max(out, dim=1)[0]
            mask = mask.unsqueeze(1) # should have shape 1, 1, 100, 100
            curr_color = color.reshape(1, 3, 1, 1)
            segment = mask*curr_color # should have shape 1, 3, 100, 100
            output += segment

This solution works perfectly fine for a batch size of 1. But when I tried a batch size of 8, mask had the shape [8, 1, 100, 100], and thus something like segment = mask*curr_color won’t work. Do you know how we can modify this so that it works for batch sizes other than 1? Thank you!


Can you describe in more detail what the problem is? I’m not sure I understand where having a batch size larger than one breaks things, as the max is taken across the class dimension and the remaining computations are elementwise or broadcast across all dimensions.
Out of the box, the code still runs on my end:

>>> import torch
>>> out = torch.randn(8, 3, 100, 100)
>>> mask = out[:, 1, :, :] == torch.max(out, dim=1)[0]
>>> mask = mask.unsqueeze(1)
>>> curr_color = torch.randn(1, 3, 1, 1)
>>> segment = mask*curr_color
>>> segment.shape
torch.Size([8, 3, 100, 100])

Actually what I got was the following error, sorry for not being specific:

output += segment
RuntimeError: output with shape [1, 3, 1000, 1000] doesn't match the broadcast shape [8, 3, 1000, 1000]

I was able to fix this issue by changing

output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)

to

output = torch.zeros(out.shape[0], 3, out.size(-2), out.size(-1), dtype=torch.float) # this ensures that, regardless of the batch size, the output image shape will match up

But now my output image consists of a grid of batch size (8) images… So output contains 8 images; do you know how I can iterate through the first dimension and get each of the 8 images from it? Thank you very much!

Yes, you would need to increase the output size.

You can access each [3, 1000, 1000] image by using an indexing operation:

image0 = output[0,:,:,:]
image1 = output[1,:,:,:]
...

Saving the images can also be done in this way (a sketch):

for i in range(output.shape[0]):
    curr_image = output[i,:,:,:]
    torchvision.utils.save_image(curr_image, f"{folder}/pred_{i}.png")
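
Note that output[i,:,:,:] has shape [3, 1000, 1000] (indexing drops the batch dimension), and save_image accepts a 3D CHW tensor like that directly; the grid you were seeing is what save_image produces when it is handed the full 4D batch at once.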

Works like a charm! Thank you very much!