Hi @ptrblck,
I have a few questions about the color mapping example from this post.
In the code below, are you just creating a dummy mask image? Since I already have the masked RGB image, I wouldn’t need that code, correct?
# Create dummy target image
nb_classes = 19 - 1 # 18 classes + background
idx = np.linspace(0., 1., nb_classes)
cmap = matplotlib.cm.get_cmap('viridis')
rgb = cmap(idx, bytes=True)[:, :3] # Remove alpha value
h, w = 190, 100
rgb = rgb.repeat(1000, 0)
target = np.zeros((h*w, 3), dtype=np.uint8)
target[:rgb.shape[0]] = rgb
target = target.reshape(h, w, 3)
This is what I have so far for the function that I am going to include in the MyDataset class:
def convertTargetToMatrix(self, target):
    # Assumes target is a uint8 tensor of shape [3, h, w] (channels first)
    h = 1000
    w = 750
    mapping = {}  # Dictionary where key is the class id, and value is the color in the mask
    mapping[0] = (0, 0, 0)  # Class 0 = background
    mapping[1] = (128, 128, 128)  # Class 1 = car
    mask = torch.zeros(h, w, dtype=torch.long)  # Mask to be filled in below; unmatched pixels stay 0
    # TODO: Change each RGB value in the color mask to its corresponding class index
    # Pseudo code below:
    for y in range(h):
        for x in range(w):
            rgbValue = tuple(target[:, y, x].tolist())  # RGB of the target at (x, y)
            if rgbValue == mapping[0]:  # pixel is background
                mask[y][x] = 0
            elif rgbValue == mapping[1]:  # pixel is car
                mask[y][x] = 1
    return mask
Can you let me know if I am going in the right direction or if I am doing something wrong here? The pseudo code I wrote is how I conceptually understand what’s happening (i.e. replacing RGB values with class ids), but if there is a specific function in torch or torchvision that does this better and more efficiently, please let me know. I don’t know if the code below is doing that, as I don’t really understand it.
for k in mapping:
    # Get all indices for current class
    idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
    validx = (idx.sum(0) == 3) # Check that all channels match
    mask[validx] = torch.tensor(mapping[k], dtype=torch.long)
Specifically, I don’t understand what the line below is doing:
idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
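For reference, here is a small sketch of how that line appears to behave, written under a few assumptions: the target is a uint8 tensor of shape [3, h, w], and the mapping keys are RGB tuples with class indices as values (the reverse of the mapping in my function above). The tiny 2x2 target and the helper name rgb_to_class_mask are made up for illustration only.

```python
import torch

# Hypothetical mapping for illustration: key is the RGB color, value is the class index
mapping = {(0, 0, 0): 0, (128, 128, 128): 1}

# Tiny made-up target of shape [3, 2, 2] (channels first), uint8 like a loaded mask image
target = torch.tensor(
    [[[0, 128], [128, 0]],
     [[0, 128], [128, 0]],
     [[0, 128], [128, 0]]], dtype=torch.uint8)


def rgb_to_class_mask(target, mapping):
    # Start from zeros so pixels that match no color fall back to class 0
    mask = torch.zeros(target.shape[1], target.shape[2], dtype=torch.long)
    for color, class_idx in mapping.items():
        # color has shape [3]; the two unsqueeze calls give [3, 1, 1],
        # which broadcasts against the [3, h, w] target
        color_t = torch.tensor(color, dtype=torch.uint8).unsqueeze(1).unsqueeze(2)
        idx = (target == color_t)   # bool tensor [3, h, w]: per-channel match
        validx = (idx.sum(0) == 3)  # bool tensor [h, w]: True where all 3 channels match
        mask[validx] = class_idx
    return mask


mask = rgb_to_class_mask(target, mapping)
print(mask)
```

If this reading is right, the comparison runs over the whole image at once for each color, so the Python loop only iterates over the number of classes rather than over every pixel.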