RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

Issue: RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

Problem Statement: I have an image in which each pixel belongs to exactly one of 'Band5', 'Band6', or 'Band7' (see below for details). Hence, I have a PyTorch multi-class problem, but I don’t understand how to set up the targets, which need to be in the form [batch, w, h].

My dataloader returns two values:

x = chips.loc[:, :, :, self.input_bands]     
y = chips.loc[:, :, :, self.output_bands]        
x = x.transpose('chip','channel','x','y')
y_ohe = y.transpose('chip','channel','x','y')
return x, y_ohe

Also, I have defined:

input_bands = ['Band1','Band2', 'Band3', 'Band3', 'Band4']  # input bands
output_bands = ['Band5','Band6', 'Band7'] #target classes

model = ModelName(num_classes = 3, depth=default_depth, in_channels=5, merge_mode='concat').to(device)
loss_new = nn.CrossEntropyLoss()

In my training function:

        #get values from dataloader
        X = normalize_zero_to_one(X) #input
        y = normalize_zero_to_one(y) #target

        images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]
        masks = Variable(torch.from_numpy(y)).to(device) 
        optim.zero_grad()        
        outputs = model(images) 

        loss = loss_new(outputs, masks) # (preds, target)
        loss.backward()         
        optim.step() # Update weights  

I know that the target (here masks) should be [batch_size, w, h]. However, it is currently [batch_size, channels, w, h].

I have read a lot of posts, including 1 and 2, and they say the target should only contain the class indices. I don’t understand how I can combine the indices of three classes and still get a target of shape [batch_size, w, h].

Right now, I get the error:

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

To the best of my understanding, I don’t need to do any one-hot encoding. Similar errors and explanations I found on the internet are here:

Any help will be appreciated! Thank you.

@ptrblck I have seen that you have answered most of the similar questions in the past; I have even linked them in the references. You mention what the target shape should look like, but I didn’t understand how to achieve that change in shape.

I’d really appreciate it if you could help.

Thank you!

Assuming your current target is one-hot encoded in the channel dimension, i.e. it uses a 1 for the “active” class in that channel while all other channels contain zeros, you could use:

target = torch.argmax(target, dim=1)

to create the target with the expected class indices.
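A minimal sketch of this conversion, using a small random one-hot target in place of the real masks (the shapes and the random data are assumptions for illustration only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, num_classes, height, width = 2, 3, 4, 4

# Hypothetical one-hot target: exactly one channel is 1 for every pixel
class_idx = torch.randint(0, num_classes, (batch_size, height, width))
one_hot = F.one_hot(class_idx, num_classes)           # [N, H, W, C]
target_one_hot = one_hot.permute(0, 3, 1, 2).float()  # [N, C, H, W], like `masks`

# Collapse the channel dimension into class indices: [N, C, H, W] -> [N, H, W]
target = torch.argmax(target_one_hot, dim=1)

# Stand-in for the model output (raw logits): [N, C, H, W]
logits = torch.randn(batch_size, num_classes, height, width, requires_grad=True)

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)  # target: int64 class indices of shape [N, H, W]
loss.backward()
print(target.shape, target.dtype)  # torch.Size([2, 4, 4]) torch.int64
```

The recovered `target` is identical to `class_idx`, which is why no one-hot target is needed for `nn.CrossEntropyLoss`.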

If that doesn’t work for you, could you post an example target with its values?

Also, don’t use Variables anymore, as they have been deprecated since PyTorch 0.4.
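Without `Variable`, the NumPy arrays can be converted directly with `torch.from_numpy`. A minimal sketch, assuming `X` and `y` are NumPy arrays shaped like the ones printed later in this thread (random stand-ins here). NumPy defaults to `float64`, so the input is cast to `float32` to match the default model parameters (a `float64` input would typically raise a dtype mismatch error), and the target is turned into `int64` class indices:

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for the dataloader output
X = np.random.rand(1, 5, 256, 256)                     # float64 input bands
y = np.eye(3)[np.random.randint(0, 3, (1, 256, 256))]  # one-hot, [1, 256, 256, 3]
y = y.transpose(0, 3, 1, 2)                            # [1, 3, 256, 256]

# No Variable wrapper: plain tensors work with autograd since PyTorch 0.4
images = torch.from_numpy(X).float().to(device)        # [batch, channel, H, W]
masks = torch.from_numpy(y).argmax(dim=1).to(device)   # int64 indices, [batch, H, W]
```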

Hi @ptrblck

Thank you! Two questions:

  1. In one of your past posts here:

No, for multi-class classification (one target class for each sample), the targets should hold the class indices. Other frameworks often use one-hot encoded target vectors, which is not necessary in PyTorch. Have a look at the docs for more information.

So, do the target vectors need to be one-hot encoded or not?

  2. You said “Assuming your current target is one-hot encoded in the channel dimension, i.e. it uses a 1 for the ‘active’ class in that channel while all other channels contain zeros”. How should I encode the data for three different classes?

Currently my target (masks in the code snippet in question) looks like this:

#print(masks)
tensor([[[[0., 0., 0., …, 0., 0., 0.],
          [0., 0., 0., …, 0., 0., 0.],
          [0., 0., 0., …, 0., 0., 1.],
          …,
          [1., 1., 0., …, 0., 0., 1.],
          [1., 0., 1., …, 0., 0., 0.],
          [1., 1., 0., …, 0., 0., 0.]],

         [[1., 1., 1., …, 0., 0., 0.],
          [1., 1., 1., …, 0., 0., 0.],
          [1., 1., 1., …, 0., 0., 0.],
          …,
          [0., 0., 0., …, 1., 1., 0.],
          [0., 0., 0., …, 1., 1., 1.],
          [0., 0., 1., …, 1., 1., 1.]]]])

Also, the shapes of X, y, and masks are, respectively:

(1, 5, 256, 256)
(1, 3, 256, 256)
torch.Size([1, 3, 256, 256])

For a multi-class classification or segmentation, the target should not be one-hot encoded.
Based on your masks output, it seems that you might be dealing with a multi-label segmentation, where each pixel might belong to zero, one, or multiple classes.
Some pixels have no active class (e.g. row0, last column). Is that your use case?
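One way to answer that question is to count the active channels per pixel: for a multi-class target, every pixel must have exactly one 1. A minimal sketch, assuming `masks` is a float tensor of shape [batch, classes, H, W] (a tiny hand-made example here, containing one pixel with no active class):

```python
import torch

# Hypothetical [1, 2, 2, 2] target; pixel (row 0, col 1) has no active class
masks = torch.tensor([[[[1., 0.],
                        [0., 1.]],
                       [[0., 0.],
                        [1., 0.]]]])

active_per_pixel = masks.sum(dim=1)  # [batch, H, W]
is_multiclass = bool((active_per_pixel == 1).all())
print(is_multiclass)  # False

# (batch, row, col) of pixels that do not have exactly one active class
bad_pixels = (active_per_pixel != 1).nonzero()
print(bad_pixels)  # tensor([[0, 0, 1]])
```

If such pixels exist, `torch.argmax(masks, dim=1)` would silently assign them to class 0, since it returns the index of the first maximal value, so this check is worth running before the conversion.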

Sorry if my code is wrong somewhere, but in my case each pixel can only belong to one class, i.e. to exactly one of the output bands (shown in the code snippet).

I also edited my previous post to show more data along with the shapes.

Ah OK, in that case each pixel should have exactly one 1 in the channel dimension, and you can use:

target = torch.argmax(target, dim=1)

to create the valid target tensor for nn.CrossEntropyLoss.

OK! But my target shape is currently (1, 3, 256, 256). Doesn’t target = torch.argmax(target, dim=1) expect something like (1, 1, 256, 256)?

I’ll try the code snippet in the meanwhile.

Maybe I just don’t understand how `target = torch.argmax(target, dim=1)` encodes the data for all 3 output channels!
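To illustrate what `torch.argmax(target, dim=1)` does with the three channels: for every pixel it looks along the channel dimension and returns the index of the channel holding the 1, so the channel dimension disappears and no (1, 1, 256, 256) intermediate is needed. A small sketch with a made-up 1 x 3 x 2 x 2 target (the band names in the comments assume the channel order follows `output_bands`):

```python
import torch

# Hypothetical one-hot target [batch=1, classes=3, H=2, W=2]
target = torch.tensor([[[[1., 0.],    # channel 0 (e.g. 'Band5')
                         [0., 0.]],
                        [[0., 1.],    # channel 1 (e.g. 'Band6')
                         [0., 0.]],
                        [[0., 0.],    # channel 2 (e.g. 'Band7')
                         [1., 1.]]]])

indices = torch.argmax(target, dim=1)
print(indices.shape)  # torch.Size([1, 2, 2])
print(indices)
# tensor([[[0, 1],
#          [2, 2]]])
```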

@ptrblck Thank you for the help! I see the training loop now works with the loss, but how would I get the predictions for each class from it?

Edit 1:

My current test code looks like this, but how do I get the right predictions?

        X = normalize_zero_to_one(X)
        y = normalize_zero_to_one(y)
        
        images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]
        masks = Variable(torch.from_numpy(y)).to(device)
        
        masks = torch.argmax(masks, dim=1)
        
        outputs = model(images)            
        loss = loss_new(outputs, masks)          
        output_sigmoid = torch.sigmoid(outputs)
        preds = output_sigmoid > output_sigmoid.cpu().numpy().mean()

I don’t think I need a sigmoid here, do I?

Assuming that you are working with 3 classes (based on the last edit of the post with the shape information), your model output should have the shape [batch_size, nb_classes=3, height, width].
To get the predicted class indices, you can use the same method:

preds = torch.argmax(output, dim=1)
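A minimal sketch of that step, with random logits standing in for the model output (the shapes are assumptions matching the 3-class setup):

```python
import torch

batch_size, num_classes, height, width = 1, 3, 256, 256

# Stand-in for `outputs = model(images)`: raw logits [N, C, H, W]
outputs = torch.randn(batch_size, num_classes, height, width)

# In a real evaluation loop, no_grad avoids building the autograd graph
with torch.no_grad():
    preds = torch.argmax(outputs, dim=1)  # [N, H, W], one class index per pixel

print(preds.shape)  # torch.Size([1, 256, 256])
```

Every entry of `preds` is an integer in `[0, num_classes - 1]`, with the same spatial shape as the target created from the masks.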

Thanks @ptrblck
Yes, my output does have the shape [batch_size, nb_classes=3, height, width], so it seems I am heading in the right direction.

Will preds = torch.argmax(output, dim=1) give me the prediction values (or probabilities)? Will I need to use a softmax or a sigmoid for thresholding?

This will give you the predicted class indices, not probabilities.
output should contain the raw logits. If you need to see probabilities, you should apply softmax to it along the class dimension (dim=1).
However, you don’t need to use the softmax to get the most likely predicted class: softmax preserves the ordering of the logits, so argmax(logits, 1) will yield the same result as argmax(softmax(logits, 1), 1).

A threshold is not necessary for a multi-class prediction.
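A minimal sketch of both points, with random logits standing in for the model output (the small shape is an assumption for illustration):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 3, 4, 4)  # stand-in for the raw model output [N, C, H, W]

# Per-pixel class probabilities: softmax over the class dimension
probs = torch.softmax(logits, dim=1)
print(torch.allclose(probs.sum(dim=1), torch.ones(1, 4, 4)))  # True

# Softmax keeps the ordering of the logits, so argmax is unchanged
preds_from_logits = torch.argmax(logits, dim=1)
preds_from_probs = torch.argmax(probs, dim=1)
print(torch.equal(preds_from_logits, preds_from_probs))  # True

# Optional: probability of the predicted class at each pixel
confidence, _ = probs.max(dim=1)  # [N, H, W]
```

No threshold appears anywhere: each pixel simply takes the class with the largest logit.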

Thanks @ptrblck. One more question (sorry for the trouble; I am working on a multi-class problem for the first time):

I am looking at the output:

        X = normalize_zero_to_one(X)
        y = normalize_zero_to_one(y)
        
        images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]
        masks = Variable(torch.from_numpy(y)).to(device)
        
        masks = torch.argmax(masks, dim=1)
        outputs = model(images)            
        loss = loss_new(outputs, masks)          
        
        preds = torch.argmax(outputs, dim=1)
        
        print(preds.shape)
        print(preds)

torch.Size([1, 256, 256])
tensor([[[1, 1, 1, …, 1, 1, 1],
         [1, 1, 1, …, 1, 1, 1],
         [1, 1, 1, …, 1, 1, 0],
         …,
         [1, 1, 1, …, 1, 1, 1],
         [1, 1, 1, …, 1, 1, 0],
         [1, 0, 0, …, 1, 1, 0]]])

Can you help me understand how to read this output? When the first value is 1, how do I tell which output band it refers to?

P.S. For simplicity, I changed my output classes to 2 classes ('Band5', 'Band6') and got an output shape of [batch_size, nb_classes=2, height, width]. The printed result above reflects this.

The preds tensor represents the predicted class index for each pixel.
I.e. you could interpret the tensor as:

tensor([[[class1, class1, class1, …, class1, class1, class1],
          ...
         [class1, class1, class1, …, class1, class1, class0],
          ...,

For the original problem with 3 classes, you should see the values [0, 1, 2] in the preds tensor.
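Regarding which band an index refers to: since the target was built with `torch.argmax(masks, dim=1)`, class index k corresponds to channel k of `y`, i.e. to `output_bands[k]` (assuming the channel order of `y` follows `output_bands`, as in the dataloader above). A small sketch that translates a prediction map into band names and per-band pixel counts (the `preds` values are made up):

```python
import torch

output_bands = ['Band5', 'Band6', 'Band7']

# Made-up prediction map [1, 2, 3] holding class indices
preds = torch.tensor([[[1, 1, 0],
                       [2, 1, 0]]])

# Class index k <-> channel k of the target <-> output_bands[k]
for k, band in enumerate(output_bands):
    n_pixels = int((preds == k).sum())
    print(f"class {k} ({band}): {n_pixels} pixel(s)")
# class 0 (Band5): 2 pixel(s)
# class 1 (Band6): 3 pixel(s)
# class 2 (Band7): 1 pixel(s)

# Band name of a single pixel, e.g. the top-left one
print(output_bands[preds[0, 0, 0].item()])  # Band6
```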

Ah! Interesting. So, if I want to extract the pixel values where a certain class is predicted, say class 1, I should extract the indices where preds equals 1 and use them to plot the values.

I hope I got it right this time!

Yes, you could get the pixel locations via:

class0_loc = (preds==0).nonzero()
class1_loc = (preds==1).nonzero()
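`nonzero()` returns one row of (batch, row, col) coordinates per matching pixel. If the goal is the input band values at those locations, a boolean mask can index the input directly. A minimal sketch, assuming `images` has shape [1, 5, H, W] and `preds` has shape [1, H, W] (random stand-ins here):

```python
import torch

torch.manual_seed(0)
images = torch.rand(1, 5, 8, 8)         # stand-in input [batch, bands, H, W]
preds = torch.randint(0, 3, (1, 8, 8))  # stand-in predictions [batch, H, W]

class1_loc = (preds == 1).nonzero()     # [num_pixels, 3]: (batch, row, col)

# Boolean mask for the first sample, then gather the band values of those pixels
mask = preds[0] == 1                    # [H, W]
class1_values = images[0][:, mask]      # [bands=5, num_pixels]
print(class1_loc.shape, class1_values.shape)
```

Both the boolean mask and `nonzero()` enumerate pixels in the same row-major order, so column i of `class1_values` belongs to row i of `class1_loc`.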

You can also plot the prediction directly e.g. using matplotlib via:

plt.imshow(preds[0].cpu())

Thanks @ptrblck That was really helpful!

@ptrblck Thank you for your help in working with Cross Entropy Loss. However, the predictions (for three classes) seem to be way off.

When I use the code below:

            print("shape of model input images:", images.shape)
            print("shape of masks_input:", masks_input.shape)
            print('masks input:')
            print(masks_input)
            outputs = model(images)  #model doesn't have softmax and images are normalized
            masks = torch.argmax(masks_input, dim=1)
            loss = loss_new(outputs, masks) 
            preds = torch.argmax(outputs, dim=1)
            print("shape of preds:", preds.shape)
            print('preds output')
            print(preds)

I get the following preds output. Basically, 95% of the pixels are predicted as the first class (index 0):

shape of model input images: torch.Size([1, 5, 256, 256])
shape of masks_input: torch.Size([1, 3, 256, 256])
masks input:
tensor([[[[0., 0., 0.,  ..., 1., 1., 1.],
          [0., 0., 0.,  ..., 1., 1., 1.],
          [0., 0., 0.,  ..., 1., 1., 0.],
          ...,
          [0., 0., 1.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 1.],
          ...,
          [1., 1., 0.,  ..., 0., 0., 1.],
          [1., 0., 1.,  ..., 0., 0., 0.],
          [1., 1., 0.,  ..., 0., 0., 0.]],

         [[1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 1., 1., 0.],
          [0., 0., 0.,  ..., 1., 1., 1.],
          [0., 0., 1.,  ..., 1., 1., 1.]]]])
shape of preds: torch.Size([1, 256, 256])
preds output
tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 2],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 2],
         [0, 0, 0,  ..., 0, 0, 0]]])

Your model might be overfitting to the majority class.
Could you check the class distribution in your masks, i.e. how many pixels belong to which class?
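A minimal sketch of such a check, assuming one-hot masks of shape [N, C, H, W] as in the printout above; `class_distribution` is a hypothetical helper name, and in practice the counts would be accumulated over all training chips rather than a single batch:

```python
import torch
import torch.nn.functional as F

def class_distribution(masks_one_hot: torch.Tensor) -> torch.Tensor:
    """Fraction of pixels per class for a one-hot target [N, C, H, W]."""
    num_classes = masks_one_hot.shape[1]
    indices = torch.argmax(masks_one_hot, dim=1)  # [N, H, W]
    counts = torch.bincount(indices.flatten(), minlength=num_classes)
    return counts.float() / counts.sum()

# Made-up example: 6 of 8 pixels belong to class 0
idx = torch.tensor([[[0, 0, 0, 1],
                     [0, 0, 0, 2]]])
masks = F.one_hot(idx, 3).permute(0, 3, 1, 2).float()
print(class_distribution(masks))  # tensor([0.7500, 0.1250, 0.1250])
```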
If you have a lot of background (class0), your model might only learn this class, and you could use e.g. focal loss to counter it.
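A minimal sketch of a multi-class focal loss for segmentation logits, following the common formulation FL = -(1 - p_t)^gamma * log(p_t). The class name `FocalLoss` is hypothetical (`torch.nn` does not ship one), and `gamma=2.0` is a commonly used default, not a value from this thread. A simpler alternative is to pass per-class weights via the `weight` argument of `nn.CrossEntropyLoss`.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Hypothetical multi-class focal loss: logits [N, C, H, W], target [N, H, W]."""

    def __init__(self, gamma: float = 2.0, weight: Optional[torch.Tensor] = None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight  # optional per-class weights (plain mean, not weight-normalized)

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(logits, dim=1)                      # [N, C, H, W]
        # log p_t: log-probability assigned to the true class of each pixel
        log_pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # [N, H, W]
        pt = log_pt.exp()
        # Unreduced per-pixel cross entropy (optionally class-weighted)
        ce = F.nll_loss(log_probs, target, weight=self.weight, reduction="none")
        # (1 - p_t)^gamma down-weights pixels that are already well classified
        return ((1.0 - pt) ** self.gamma * ce).mean()


# Usage with random stand-in data
torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)
target = torch.randint(0, 3, (2, 4, 4))
loss = FocalLoss(gamma=2.0)(logits, target)
loss.backward()
```

With `gamma=0` and no weights this reduces to the mean cross entropy, which makes it easy to compare against `nn.CrossEntropyLoss` before tuning `gamma`.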