Only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [1, 1, 256, 256]

I’m dealing with a multi-class problem where each pixel can be assigned to a single class only. I had 4 different masks for each image, where each mask represents one class, so I had 4 binary masks. What I have done now is convert the binary masks to have values 0/1 for mask 1, 0/2 for mask 2, and so on, and then add all of them together to get a single mask.


So how can I do the color mapping for each class now?


The target looks correct.
You could most likely create it by using torch.argmax(target, dim=1).

What kind of color mapping do you need?
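A minimal sketch of the torch.argmax suggestion, assuming the four binary masks are stacked along a channel dimension into a tensor of shape [N, 4, H, W]. The variable names and the tiny 2x2 image are illustrative and not taken from the original post.

```python
import torch

# Hypothetical example: 4 binary masks (one per class) for a 2x2 image,
# stacked along dim=1 to give a tensor of shape [N, C, H, W] = [1, 4, 2, 2].
masks = torch.zeros(1, 4, 2, 2)
masks[0, 0, 0, 0] = 1  # pixel (0, 0) belongs to class 0
masks[0, 1, 0, 1] = 1  # pixel (0, 1) belongs to class 1
masks[0, 2, 1, 0] = 1  # pixel (1, 0) belongs to class 2
masks[0, 3, 1, 1] = 1  # pixel (1, 1) belongs to class 3

# argmax over the channel dimension yields one class index per pixel.
target = torch.argmax(masks, dim=1)  # shape [1, 2, 2], dtype torch.int64

print(target.shape)
print(target)
```

The resulting [N, H, W] tensor of class indices is the target layout nn.CrossEntropyLoss expects for segmentation.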

I’m sorry, I didn’t get

Is this for creating the mapping ?

I need simple color mapping, 4 colors representing each class.

Do you need this to restore the original target image?
If so, you could create a mapping with e.g. a dict and index it with your target tensor.
Note that your current target tensor is suitable to be passed into nn.CrossEntropyLoss.

I’m using only nn.CrossEntropyLoss. I need the color mapping to visualize whether the segmentation is done correctly.

In this case indexing should work:

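import torch

# One RGB color per class index (3 classes in this example)
cmap = torch.tensor([[255, 0, 0],
                     [0, 255, 0],
                     [0, 0, 255]])

# Class-index target of shape [N, H, W]
target = torch.randint(0, 3, (1, 10, 10))
# Indexing cmap with the target gives colors of shape [N, H, W, 3]
res = cmap[target]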
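The example above uses three colors, while the question involves four classes. A hedged four-class variant might look like the following; the fourth color and the uint8 dtype are arbitrary choices made for this sketch.

```python
import torch

# One RGB color per class; the fourth color (yellow) is an arbitrary choice.
cmap = torch.tensor([[255, 0, 0],
                     [0, 255, 0],
                     [0, 0, 255],
                     [255, 255, 0]], dtype=torch.uint8)

# Stand-in target with class indices 0..3, shape [N, H, W].
target = torch.randint(0, 4, (1, 10, 10))

# Advanced indexing looks up one RGB row per pixel: shape [N, H, W, 3].
res = cmap[target]

# A single channels-last image, e.g. for plotting, has shape [H, W, 3].
image = res[0]
print(res.shape, image.shape)
```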

It seems like the index mapping can work.

Also, correct me if I’m wrong. Here is a breakdown of my procedure:

  1. I created a custom dataset that fetches the images and the masks (with the masks combined as described above)
  2. The dataloader is iterating over these.
  3. I’m using a pretrained DeepLabV3 model with the head changed to 4 outputs (the total number of classes for me)
  4. Using nn.CrossEntropyLoss
  5. The training doesn’t seem to be promising at all.

Is there anything I should be doing extra for semantic segmentation?

The procedure looks correct.
I would recommend trying to overfit a small data sample (e.g. just 10 data samples) to verify the training procedure does not contain any hidden errors.
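One way the overfitting check could be sketched is shown below. Everything here is a stand-in: the random tensors replace the real dataset, the tiny convolutional network replaces DeepLabV3, and the learning rate and epoch count are arbitrary. The point is only that the loss on 10 samples should drop clearly if the training loop is wired correctly.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)

# Stand-in data: 50 random 3-channel "images" with per-pixel targets
# derived from the input so that the task is learnable.
images = torch.randn(50, 3, 16, 16)
targets = images.argmax(dim=1)  # [50, 16, 16], values in [0, 2]
full_dataset = TensorDataset(images, targets)

# Keep only the first 10 samples for the overfitting test.
small_dataset = Subset(full_dataset, range(10))
loader = DataLoader(small_dataset, batch_size=5, shuffle=True)

# Tiny stand-in model with a 4-class output (not DeepLabV3).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 4, kernel_size=1),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

epoch_losses = []
for epoch in range(30):
    running = 0.0
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)  # logits [N, 4, H, W], target [N, H, W]
        loss.backward()
        optimizer.step()
        running += loss.item()
    epoch_losses.append(running / len(loader))

print(f"first epoch loss: {epoch_losses[0]:.4f}, last: {epoch_losses[-1]:.4f}")
```

If the loss on such a small subset does not fall, the problem is more likely in the data pipeline or the loop than in the amount of training.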

I just ran 5 epochs. I have about 250 images.
I’m a bit confused now. The shape of my outputs from the model is torch.Size([5, 4, 224, 224]). Batch size is 5 and there are 4 masks.

Now, to get the prediction, how do I combine these 4 separate outputs into a single mask?

pred = torch.argmax(output, 1) would give you the predicted class indices, which you could then pass to your mapping to get the corresponding colors.
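A short sketch of that step, using a random tensor as a stand-in for the model output with the shape reported above. The pixel accuracy at the end is an extra illustration and uses a random stand-in target.

```python
import torch

# Stand-in for the model logits: [batch, classes, height, width].
output = torch.randn(5, 4, 224, 224)

# Pick the highest-scoring class per pixel: shape [5, 224, 224].
pred = torch.argmax(output, dim=1)
print(pred.shape)

# Optional: fraction of pixels that match a class-index target.
target = torch.randint(0, 4, (5, 224, 224))
pixel_accuracy = (pred == target).float().mean()
print(pixel_accuracy.item())
```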

Thank you. This makes sense to use index with the class for mapping!

@ptrblck, turns out my masks weren’t binary images. So what I did was I changed the pixel value to class value wherever there was a non-zero value. Here 2 is one of the classes.

m2 =
m2 = np.asarray(m2)
m2 = np.where(m2>0, 2, m2)
m2_tensor = self.to_tensor(m2)

Is this the right approach?
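A hedged sketch of the same idea with an explicit integer conversion. The small array stands in for the real mask, and what self.to_tensor does is not shown in the post. If it happens to be torchvision's transforms.ToTensor (an assumption), note that it rescales uint8 arrays to floats in [0, 1], so the class value would no longer be the integer 2; torch.from_numpy keeps the values unchanged.

```python
import numpy as np
import torch

# Hypothetical non-binary mask for class 2 (any non-zero pixel is foreground).
m2 = np.array([[0, 37],
               [255, 0]], dtype=np.uint8)

# Replace every non-zero pixel with the class value 2, keep background at 0,
# and use int64 so the tensor can serve as a class-index target.
m2 = np.where(m2 > 0, 2, 0).astype(np.int64)

# from_numpy preserves dtype and values (no rescaling).
m2_tensor = torch.from_numpy(m2)
print(m2_tensor)
```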

Hello @ptrblck, Thank you. I also faced the same problem and found your solution.
However, my code is showing some strange behavior. For example, when I write this code,

def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        y = y.squeeze(axis = 1)

the line y = y.squeeze(axis = 1) neither prints nor works.

But when I change it a little, e.g. by storing the value in another tensor, it prints and works.

y_label = y.squeeze(axis = 1)

Could you tell me, why the code is not working for the first one y = y.squeeze(axis = 1)?

I don’t know why this would be the case.
Could you post an executable code snippet, which shows this behavior?
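For illustration, an executable snippet is a small script that runs on its own with made-up data. A minimal one for the squeeze case might look like this, where the tensor shape is an assumption and not taken from the poster's code:

```python
import torch

# Stand-in target batch with a singleton channel dimension: [N, 1, H, W].
y = torch.randint(0, 4, (5, 1, 8, 8))
print(y.shape)

# Reassigning to the same name drops dim 1 and gives [N, H, W].
y = y.squeeze(dim=1)
print(y.shape)
```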

Thank you for your comment. I am not sure what an executable code snippet actually is!

I posted about this issue here.

Let me know your thoughts.

@ptrblck Hey! I am training a UNet and I got the same error as above. I tried your torch.squeeze(1) suggestion and it gave me another error, IndexError: Target -2 is out of bounds, which is coming from my loss function. Here is my loss function and training step:

class LogNLLLoss(_WeightedLoss):
    __constants__ = ['weight', 'reduction', 'ignore_index']

    def __init__(self, weight=None, size_average=None, reduce=None, reduction=None,
        super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, y_input, y_target):
        y_input = torch.log(y_input + EPSILON)
        return cross_entropy(y_input, y_target, weight=self.weight,

and the training function:

    def train_epoch(self, dataloader, dice_loss):
        epoch_running_loss = 0
        for batch_idx, (x_batch, y_batch) in enumerate(dataloader):
            x_batch =
            y_batch =
            y_out =
            training_loss = self.loss(y_out, y_batch.squeeze(1))
            train_dice = dice_loss(y_out, y_batch)
            epoch_running_loss += training_loss.item()
        return (epoch_running_loss/len(dataloader)), train_dice

Is there any way I can fix this error? What am I doing wrong?

Thanks in advance

Based on the errors you are seeing I guess the y_batch.squeeze(1) operation might have returned a tensor in the right shape, but the tensor itself contains invalid indices (-2 in particular).
Assuming you are using F.cross_entropy, the target should contain indices in the range [0, nb_classes-1], so you would have to check why the -2 is there.
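A small sketch of such a check. The helper name check_target is hypothetical, and it does not account for a custom ignore_index value, which would need to be allowed explicitly.

```python
import torch

def check_target(target: torch.Tensor, nb_classes: int) -> None:
    """Raise ValueError if target holds indices outside [0, nb_classes - 1]."""
    values = torch.unique(target)
    invalid = values[(values < 0) | (values >= nb_classes)]
    if invalid.numel() > 0:
        raise ValueError(f"invalid class indices: {invalid.tolist()}")

good = torch.tensor([[0, 1], [1, 0]])
bad = torch.tensor([[0, -2], [1, 0]])

check_target(good, nb_classes=2)  # passes silently
try:
    check_target(bad, nb_classes=2)
except ValueError as err:
    print(err)
```

Calling a check like this on each batch before the loss makes it easier to see which samples carry the unexpected values.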

@ptrblck I have 2 classes which are white and black. Do you think it could be from that? Can I print out the shape of anything to check where this error is coming from?

You can print the shape anywhere in your code via print(tensor.shape) or print(tensor.size()).
I think the issue is that your target contains invalid values, while the shape might be alright.
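For a two-class black/white mask, one possible way to inspect the values and turn them into class indices is sketched below. It assumes an 8-bit grayscale mask and a threshold of 127, both of which are assumptions; where the -2 in the original data came from is not known.

```python
import torch

# Hypothetical 8-bit mask: near-black pixels are class 0, near-white class 1.
mask = torch.tensor([[0, 255],
                     [254, 3]], dtype=torch.uint8)

# Inspect shape, dtype and the distinct values before training.
print(mask.shape, mask.dtype, torch.unique(mask))

# Threshold to class indices 0/1 and convert to int64 for the loss.
target = (mask > 127).long()
print(torch.unique(target))
```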

Ah, you are right, it looks like some of my pixels were -2, so I had to fix that. I’m encountering another issue though: everything is working fine except for the dice loss. My dice loss is coming out as:

Step - 200 [Train Loss - 0.027800448606722056] [Dice Coeff - 0.9999991655349731]
Step - 400 [Train Loss - 0.018787917831214144] [Dice Coeff - 0.9999994039535522]
Step - 600 [Train Loss - 0.014162900804076345] [Dice Coeff - 0.9999995231628418]
Step - 800 [Train Loss - 0.0113546519659576] [Dice Coeff - 0.9999995827674866]
Step - 1000 [Train Loss - 0.009469814540469088] [Dice Coeff - 0.9999996423721313]
Step - 1200 [Train Loss - 0.008117731069796718] [Dice Coeff - 0.9999996423721313]
Step - 1400 [Train Loss - 0.007100731051427179] [Dice Coeff - 0.9999997019767761]

This is weird. Here are my dice loss implementation and train functions:

class SoftDiceLoss(nn.Module):
    def __init__(self):
        super(SoftDiceLoss, self).__init__()

    def forward(self, pred, target):
        smooth = 1.
        iflat = pred.contiguous().view(-1)
        tflat = target.contiguous().view(-1)
        intersection = (iflat * tflat).sum()
        A_sum = torch.sum(iflat * iflat)
        B_sum = torch.sum(tflat * tflat)
        return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))

    def train_epoch(self, dataloader, dice_loss):
        epoch_running_loss = 0
        for batch_idx, (x_batch, y_batch) in enumerate(dataloader):
            x_batch =
            y_batch =
            y_out =
            training_loss = self.loss(y_out, y_batch)
            train_dice = dice_loss(y_out, y_batch)
            if batch_idx % 200 == 0 and batch_idx != 0:
                print(f"Step - {batch_idx} [Train Loss - {epoch_running_loss/batch_idx}] [Dice Coeff - {train_dice.item()}]")
            epoch_running_loss += training_loss.item()
        return (epoch_running_loss/len(dataloader)), train_dice
    def train_unet(self, train_loader, val_loader, n_epochs, dice_metric):
        min_loss = np.inf
        train_time = time.time()
        dice_metric =
        logs = {}
        for epoch in range(1, n_epochs+1):
            train_loss, train_dice = self.train_epoch(train_loader, dice_metric)
            val_loss, val_dice = self.val_epoch(val_loader, dice_metric)
            logs = {'epoch': epoch,
                    'time': epoch_end - train_start,
                    'train_loss': train_loss,
                    'validation_loss': val_loss,
                    'train_dice': train_dice,
                    'validation_dice': val_dice}
            print("-" * 20)
            print(f"Epoch - {logs['epoch']} | Time Elapsed - {logs['time']} | Training Loss - {logs['train_loss']} | Train Dice Coeff - {logs['train_dice']}") 
            print(f"Validation Loss - {logs['validation_loss']} | Validation Dice - {logs['validation_dice']}")

If you could point me in the direction of what is wrong that would be great! Thanks so much.
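One thing visible in the code above is that SoftDiceLoss.forward returns 1 minus the Dice ratio, so the number printed as "Dice Coeff" is the loss, and a value near 1 corresponds to a coefficient near 0. The sketch below shows one possible way to compute a Dice coefficient from raw logits and class-index targets. The function name dice_coefficient and the eps value are hypothetical, and this is neither the poster's implementation nor a library function.

```python
import torch
import torch.nn.functional as F

def dice_coefficient(logits: torch.Tensor, target: torch.Tensor,
                     eps: float = 1e-6) -> torch.Tensor:
    """Mean soft Dice over classes.

    logits: [N, C, H, W] raw model outputs
    target: [N, H, W] class indices in [0, C - 1] (dtype int64)
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                       # [N, C, H, W]
    one_hot = F.one_hot(target, num_classes)               # [N, H, W, C]
    one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)  # [N, C, H, W]

    dims = (0, 2, 3)  # sum over batch and spatial dims, keep classes
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    dice_per_class = (2 * intersection + eps) / (denominator + eps)
    return dice_per_class.mean()

torch.manual_seed(0)
target = torch.randint(0, 2, (3, 8, 8))

# Logits that strongly favor the correct class at every pixel.
perfect_logits = F.one_hot(target, 2).permute(0, 3, 1, 2).float() * 20.0
# Logits that strongly favor the wrong class at every pixel.
wrong_logits = F.one_hot(1 - target, 2).permute(0, 3, 1, 2).float() * 20.0

print(dice_coefficient(perfect_logits, target).item())
print(dice_coefficient(wrong_logits, target).item())
```

With a metric defined this way, values close to 1 indicate good overlap, and 1 minus the result can serve as the loss.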