The size of tensor a (21) must match the size of tensor b (3) at non-singleton dimension 1

Hey guys!

I’m trying to train a neural network for multiclass semantic segmentation (FCN) using the CrossEntropyLoss function, but the following error occurs:

Preds: torch.Size([1, 21, 160, 240])
Target: torch.Size([1, 3, 160, 240])
The size of tensor a (21) must match the size of tensor b (3) at non-singleton dimension 1

As I am using a dataset with 21 classes (RGB, 8 bit), my input should be correct. But the dimension of my target is different, which is a little confusing; I don’t know how to solve it.

Should I do some kind of transformation?

Any help or guidance on this will be greatly appreciated!

Part of the code where the error occurs:

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).squeeze(1)
            preds = nn.Softmax(dim=1)(model(x))
            preds = (preds > 0.5).float()
            y = y.permute(0, 3, 1, 2) 
            print(f'Preds: {preds.shape}')
            print(f'Target: {y.shape}')
            num_correct += (preds == y).sum()         ####ERROR#####
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"{num_correct}/{num_pixels} accuracy {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice Score : {dice_score/len(loader)}")
    model.train()

Hi Alisson!

The short answer is pretty much what the error message says: You’re
trying to perform an element-wise equality test between two tensors
whose dimensions don’t match. (If you were comparing two tensors
that had the same shape, you wouldn’t get this error.)

However, I think this is a symptom of some more general structural
things you are doing wrong …

You don’t show your training code, and I’m a little surprised that you’re
not reporting similar errors with that, given what it looks like you’re
doing wrong.

Anyway, I would expect that you would be using targets (and
inputs) that have the same format for both training and your
accuracy evaluation.

The input to your model, x – the batch of images you wish to
perform semantic segmentation on – would typically have shape
[nBatch, nChannel, height, width].

In your case it appears that nBatch = 1. (A batch size of 1 is fine.)

It seems that nChannel = 3 – RGB? But you say that you are using
“RGB 8 bit” which would mean that you would have a single 8-bit
channel with palettized color – that is, you have a look-up table
(LUT) that translates the 8-bit value of each pixel into an RGB color
by using it as an index into the LUT. So it’s unclear what is going on
with your channel dimension.

height = 160 and width = 240, which is fine.

This line:

y = y.permute(0, 3, 1, 2)

suggests that your images – or at least your target “images” – are
stored (individually) with shape [height, width, nChannel], so
you permute the dimensions to put them in the order preferred by
pytorch. This is also fine. (Why you have the .squeeze(1), I don’t
know.)

The core problem – and this is hard to know for sure because you
don’t show the code or shapes that you use for training – is that
your target (your y) is not formatted correctly for CrossEntropyLoss.

The output of your model (your preds) has shape
[nBatch = 1, nClass = 21, height = 160, width = 240].
This is fine. But you would then want the target (the annotated
class labels for the image) to have shape [nBatch, height, width],
with no nClass dimension. Each “pixel” in your target “image”
should be an integer class label that runs from 0 to
nClass - 1 = 20.
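
As a concrete, made-up sketch of the shapes CrossEntropyLoss
expects, using your dimensions:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# raw, unnormalized scores from the model: [nBatch, nClass, height, width]
preds = torch.randn(1, 21, 160, 240)
# integer class labels in 0 .. 20, no class dimension: [nBatch, height, width]
target = torch.randint(0, 21, (1, 160, 240))   # dtype torch.int64

loss = loss_fn(preds, target)   # no shape complaint
print(loss.item())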

It is sometimes the case that target images have the class labels
color coded – that is, all of the class-3 pixels might be colored blue,
all of the class-7 pixels might be colored gray, and so on. If this is the
case, you would have to pre-process your target images so that
they use integer class labels, as described above.

Now, on to your accuracy calculation:

preds = nn.Softmax(dim=1)(model(x))

This is okay, but you don’t actually need the Softmax here.
preds = model(x) is simpler, and – for this purpose – equivalent.

preds = (preds > 0.5).float()

This is incorrect (for the multi-class problem you say you have).
(This would make sense for binary segmentation.)

You want to convert the “soft,” probabilistic prediction made by
your model into a “hard” prediction for the single best class.

For this you want preds = preds.argmax(dim=1). This
produces a batch of prediction “images” of shape
[nBatch, height, width] (no nClass dimension) where each
pixel is the integer class label for the class the model predicts
as being the most probable.

If your target (y) is also composed of integer class labels and has
shape [nBatch, height, width], (preds == y).sum() will
count the number of correct predictions.
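
Putting these pieces together, a minimal sketch of the inside of
your evaluation loop (assuming y already holds integer class labels
of shape [nBatch, height, width]):

preds = model(x).argmax(dim=1)       # hard predictions, shape [nBatch, height, width]
num_correct += (preds == y).sum()    # shapes now match, so this counts correct pixels
num_pixels += torch.numel(preds)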

Best.

K. Frank

Hi K. Frank, thanks so much for your notes.

Sorry, I forgot to include the training code; it’s below:

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNet
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.version.cuda}')

if torch.cuda.is_available():
   DEVICE = "cuda"
   print('GPU')
   #torch.manual_seed(12)
   #torch.cuda.manual_seed(12)
   #np.random.seed(12)
else: 
   DEVICE = "cpu"
   print('CPU')

if torch.backends.cudnn.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print('cuDNN activated')
else:
    print('cuDNN not available!')

# Hyperparameters etc.
LEARNING_RATE = 3e-4
BATCH_SIZE = 1
NUM_EPOCHS = 10
NUM_WORKERS = 0
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "DATASET/train/imagens/"
TRAIN_MASK_DIR = "DATASET/train/mascaras/"
VAL_IMG_DIR = "DATASET/val/imagens/"
VAL_MASK_DIR = "DATASET/val/mascaras/"

#torch.cuda.empty_cache()

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.long().squeeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    
    model = UNet(21).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder="Results\\", device=DEVICE
        )


if __name__ == "__main__":
    main()

My code may not make much sense in places – sorry about that, I’m a beginner. I will try to explain some points better.

  1. I was also surprised that the training code didn’t report an error. But if I apply the argmax function to the predictions and the target to remove dimension 1, the error in the check_accuracy function disappears – only for a similar error to then appear in the training code.

  2. I forgot to mention, but I convert the images to RGB, because if I don’t, the colors of the masks are changed. When this happens, some different classes are given the same color. I don’t know why this happens. I’ll leave the dataset code here:

import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import time
import matplotlib
import matplotlib.pyplot as plt
import cv2
import torch

class MeuDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", ".png"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("RGB"), dtype=np.uint8)
        #mask = np.transpose(mask,(2, 0, 1))        
        #debug
        #print(f'Normal: {mask.shape}')
        #plt.imshow(mask)
        #plt.show()
        #mask = mask/classes
        #mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask
  3. Target formatting is really my biggest issue here, as you said. I’ve seen in other threads that we must convert RGB colors to class indices, but I have no idea how to do that. I’ve even tried it, but nothing has worked yet.

After changing the accuracy calculation as you suggested, I got the following error:

only batches of spatial targets supported (3D tensors) but got targets of size: [1, 160, 240, 3]

In this case, would it be sufficient to swap the dimensions as required by PyTorch and remove the dimension corresponding to the 3?

Hi Alisson!

Let me make a number of comments before outlining how one
might convert RGB colors to integer class labels.

First, just to use a clear-cut word, let me call your ground-truth,
annotated (labeled) images “masks.” So your dataset consists
of a bunch of images that will be input to your model (for both
training and validation) each with a corresponding mask that is
used either to train your model or evaluate how well your trained
model is performing its segmentation task.

Next, are your training and validation datasets in the same format?
Do your training and validation datasets come from one larger
combined dataset that was split (perhaps randomly) into the two
separate training and validation datasets?

What do your input images look like? Are they color images?
Grayscale? Do they look like sensible images that appear to the
eye to have structure that could be segmented?

What do your masks look like? Do they have blobs of different
colors in them that line up sensibly with the segments in the
corresponding input images?

Before processing, it sounds like your masks are some sort of color
images. What format are they in? Do they have three separate RGB
channels? Are they 32-bit RGB images (that might only be using
24 bits under the hood)? Are they 8-bit palettized images?

What format are your masks in after any processing that may occur
in your dataloader?

I would suggest that you print out the shapes of your input-image
and mask batches both immediately after they are delivered by
the dataloader, and then again immediately before they are fed
to your model and loss function. (The reason to print out the
shapes at these two different locations in your code is to verify
that your squeeze()s and permute()s are doing what you want.)

In the typical case (It doesn’t have to be this way, but it would be
typical.), a batch of input images would have shape
[nBatch, nChannel, height, width]. nChannel would typically be
1 for grayscale images and 3 for color images. (nChannel could be
3 for grayscale images that happen to be formatted as color images.)

Note, if you are passing color images to your model, it will only work
if your input images (by the time they get passed to your model) are
multi-channel (almost certainly three-channel RGB) images. Your
model won’t work with 32-bit color or 8-bit palettized color images.

The output of your model will typically have shape
[nBatch, nClass, height, width]. You should also print out
the shape of your model output images to verify this.

When used for training with CrossEntropyLoss your masks must
have shape [nBatch, height, width] (with no nClass nor nChannel
dimension), and the values of the mask pixels must be integer class
labels that run over 0 to nClass - 1. This is what CrossEntropyLoss
expects.

(It would be possible to have masks of a different shape or format
for your accuracy calculation – you don’t have to follow the
requirements of CrossEntropyLoss, but there is likely no benefit
to using a different format, and doing so would only confuse the issue.)

I would recommend making a simplified “test-bed” version of your
code – get rid of as much extraneous stuff as you can. Get rid of
the checkpoint stuff, cuda, amp, scaler, etc. I would keep the transforms
as they could modify the shape / format of the images and masks
that are being passed to your model, loss function, and accuracy
calculation, and that’s part of what you need to check.

Then try passing a single batch of train_loader images through
your model and loss function (printing out the shapes, as discussed
above) and try passing a single val_loader batch of images through
your model and accuracy calculation (again printing out shapes).

Make sure that everything lines up the way it should, and if it doesn’t
make sure you understand the details of what isn’t right, and what
you need to do, conceptually, to fix it.
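
A sketch of such a single-batch check (train_loader, model, and
loss_fn being your own objects):

x, y = next(iter(train_loader))
print(f'from loader:  x: {x.shape}  y: {y.shape}  y.dtype: {y.dtype}')
# apply any squeeze() / permute() you normally do here, then print again
out = model(x)
print(f'model output: {out.shape}')   # expect [nBatch, nClass, height, width]
loss = loss_fn(out, y)                # y must be [nBatch, height, width], dtype int64
print(f'loss: {loss.item()}')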

Now to the question of creating masks that consist of integer class
labels.

I will assume that your masks start out being some sort of color
images where each of a discrete set of colors (in your case,
presumably nClass = 21 different colors) indicates a specific
class.

Converting your masks to consist of integer class labels could be
done in your dataloader or in a separate processing step after
your dataloader. I would probably put this processing step in the
dataloader, but it doesn’t really matter.

You need to write the code that reads through all (or at least a
representative sample) of your masks and makes a table of all
the colors that show up in the pixels. (Conceptually, it doesn’t
matter whether the colors are encoded with three RGB channels
or as 32-bit or 8-bit colors – you just need to keep track of the
distinct colors that show up.) Each distinct color corresponds to
a class. Because you say that you have 21 classes, you had better
have no more than 21 distinct colors. For example, if you discover
that your training masks have 25 distinct colors, then your masks
are telling you that you have 25 classes, not 21, so you would have
to sort that out.
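
A sketch of that bookkeeping step (assuming, as in your dataset
code, that each mask loads as an H x W x 3 uint8 array; mask_paths
is a hypothetical list of your mask files):

import numpy as np
from PIL import Image

colors = set()
for path in mask_paths:                # hypothetical list of mask file paths
    mask = np.array(Image.open(path).convert("RGB"))
    # each row of the reshaped array is one pixel's (R, G, B) triple
    for color in np.unique(mask.reshape(-1, 3), axis=0):
        colors.add(tuple(int(c) for c in color))
print(f'{len(colors)} distinct colors found')   # should be at most 21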

Now make a dictionary (consider it to be a look-up table) whose
keys are the 21 distinct colors and values are the integers 0
through 20. The assignment of integer class label to mask color
can be arbitrary. If your classes have some “meaning,” you are
welcome to assign integer class labels that somehow reflect that
meaning for your own convenience, but the choice won’t affect
how the model training progresses. (The class-label assignment
does have to be consistent across your dataset – both your training
and validation data jointly.)
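
Given the set of colors collected above, building the look-up table
is one line (sorting just makes the arbitrary assignment reproducible):

color_to_label = {color: label for label, color in enumerate(sorted(colors))}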

(As an aside, for this kind of semantic segmentation, we often think
of an image being segmented as having “foreground” objects that
form the classes and an “uninteresting” background. For training,
however, the background pixels still form a class, and are treated
no differently than the “foreground” pixels. For your convenience
(It doesn’t affect the training.), you might choose to use 0 as the
integer class label for the background pixels.)

Now for each mask read through the pixels in the mask tensor and
create a new mask tensor of dtype = torch.int64 (the same as
torch.long) and typically of shape [height, width]. As you write
your new mask tensor on a pixel-by-pixel basis, translate the
corresponding color-valued pixel from your original mask into the
proper integer class label by passing it through the look-up table
you built.
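
Rather than a literal pixel-by-pixel loop, you can loop over the (at
most 21) colors instead; a sketch, with mask an H x W x 3 uint8
array and color_to_label the dictionary built above:

import numpy as np
import torch

label_mask = np.zeros(mask.shape[:2], dtype=np.int64)
for (r, g, b), label in color_to_label.items():
    # mark every pixel whose color matches this entry of the look-up table
    matches = (mask[:, :, 0] == r) & (mask[:, :, 1] == g) & (mask[:, :, 2] == b)
    label_mask[matches] = label
label_mask = torch.from_numpy(label_mask)   # shape [height, width], dtype torch.int64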

Good luck.

K. Frank