Only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [1, 1, 256, 256]

Hi @ptrblck ,

I actually have a few questions about the color mapping example you have from this post.

In the code below, are you just creating a dummy mask image? Since I already have the masked RGB image, I wouldn’t need that code, correct?

# Create dummy target image
nb_classes = 19 - 1 # 18 classes + background
idx = np.linspace(0., 1., nb_classes)
cmap = matplotlib.cm.get_cmap('viridis')
rgb = cmap(idx, bytes=True)[:, :3]  # Remove alpha value

h, w = 190, 100
rgb = rgb.repeat(1000, 0)
target = np.zeros((h*w, 3), dtype=np.uint8)
target[:rgb.shape[0]] = rgb
target = target.reshape(h, w, 3)

This is what I have so far for the function that I am going to include in the MyDataset class:

def convertTargetToMatrix(self, target):
    # Assumption: target is a uint8 tensor of shape [3, h, w] (channels first)
    h = 1000
    w = 750

    # Dictionary where the key is the class id and the value is the color in the mask
    mapping = {}
    mapping[0] = (0, 0, 0)  # Class 0 = background
    mapping[1] = (128, 128, 128)  # Class 1 = car

    # Uninitialized mask, to be filled in by the loop below
    mask = torch.empty(h, w, dtype=torch.long)

    # TODO: Change each RGB value in the color mask to its corresponding class index
    # Pseudo code below:
    for y in range(h):
        for x in range(w):
            # Read the (r, g, b) tuple of the pixel in row y, column x
            rgbValue = tuple(target[:, y, x].tolist())
            if rgbValue == mapping[0]:  # pixel is background
                mask[y][x] = 0
            elif rgbValue == mapping[1]:  # pixel is car
                mask[y][x] = 1

    return mask

Can you let me know if I am going in the right direction or if I am doing something wrong here? The pseudo code I wrote is how I conceptually understand what’s happening (i.e., replacing RGB values with class ids), but if there is a specific function in torch or torchvision that does this better and more efficiently, please let me know. I don’t know if the code below is doing that, as I don’t really understand it.

for k in mapping:
    # Get all indices for current class
    idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
    validx = (idx.sum(0) == 3)  # Check that all channels match
    mask[validx] = torch.tensor(mapping[k], dtype=torch.long)

Specifically, I don’t understand what the line below is doing:

idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
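For reference, here is a minimal sketch of what that loop appears to do, written as a standalone function so each step can be inspected. It assumes `target` is a `uint8` tensor of shape `[3, H, W]` and that `mapping` goes from color to class id (the reverse of the dictionary in my function above). The name `rgb_to_class_mask` and the tiny 2x2 input are made up for illustration. As far as I can tell, the two `unsqueeze` calls reshape the color from `[3]` to `[3, 1, 1]` so it broadcasts against every pixel of the `[3, H, W]` target.

```python
import torch

def rgb_to_class_mask(target, mapping):
    """Hypothetical helper: target is uint8 [3, H, W], mapping is {(r, g, b): class_id}."""
    h, w = target.shape[1], target.shape[2]
    # Start from zeros so pixels that match no color fall back to class 0
    mask = torch.zeros(h, w, dtype=torch.long)
    for color, class_id in mapping.items():
        # [3] -> [3, 1, 1], so the comparison broadcasts over all H x W pixels
        color_t = torch.tensor(color, dtype=torch.uint8).unsqueeze(1).unsqueeze(2)
        idx = (target == color_t)    # bool [3, H, W]: per-channel match
        validx = (idx.sum(0) == 3)   # bool [H, W]: True where all 3 channels match
        mask[validx] = class_id
    return mask

# Tiny 2x2 example: two gray "car" pixels on a black background
mapping = {(0, 0, 0): 0, (128, 128, 128): 1}
target = torch.zeros(3, 2, 2, dtype=torch.uint8)
target[:, 0, 1] = 128
target[:, 1, 0] = 128
mask = rgb_to_class_mask(target, mapping)
print(mask)
```

The result is a `[H, W]` tensor of class indices with no Python loop over pixels. For a batch, the targets passed to `nn.CrossEntropyLoss` would then be `[N, H, W]`, so a target of size `[1, 1, 256, 256]` like the one in the title would presumably need its channel dimension removed, e.g. with `squeeze(1)`.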