PyTorch colormap gather operation

I am trying to implement the conversion of a grayscale image to an RGB image through a colormap myself in PyTorch. I have:

  • A tensor gray_image which is a LongTensor of 512 x 512 with (integer) values between 0 and 255
  • The color_map, which is a FloatTensor of 256 x 3, where every row represents the RGB value for a specific gray value
  • An empty output tensor, which is a FloatTensor of 512 x 512 x 3

I fill the output tensor like this:

import torch

output = torch.zeros((512, 512, 3))  # Initialize empty output tensor
gray_image = (gray_image * 255).long()  # Convert to a LongTensor with values between 0 and 255

for i in range(gray_image.shape[0]):
    for j in range(gray_image.shape[1]):
        output[i, j, :] = color_map[gray_image[i, j]]

This is quite slow, and it should be possible in one or two lines using torch.gather(). I tried to do this myself but did not succeed, and I found the documentation of torch.gather() quite complex. Can someone help me with this?

You can use this code:

# color_map: FloatTensor of shape (256, 3)
gray_image = (gray_image * 255).long()  # LongTensor of shape (512, 512) with values between 0 and 255
output = color_map[gray_image]  # FloatTensor of shape (512, 512, 3)

Why does this work? The gray_image tensor is used as an index into the 0th dimension of color_map, so every gray value selects the matching RGB row. In the result, the shape of gray_image takes the place of that 0th dimension and the remaining dimension of color_map follows it: (512, 512) + (3,) = (512, 512, 3).
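To check the shape rule described above, here is a minimal sketch that compares the indexing one-liner against the explicit double loop from the question. The random color_map and gray_image, the seed, and the small 8 x 8 size are stand-ins chosen only to keep the loop fast; they are not from the original post. It also shows, as an assumption about what the question was reaching for, how the same lookup could be written with index_select and with torch.gather().

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in data; a small image keeps the reference loop fast.
height, width = 8, 8
color_map = torch.rand(256, 3)                       # (256, 3) float RGB lookup table
gray_image = torch.randint(0, 256, (height, width))  # (8, 8) int64 indices in [0, 255]

# Reference: the explicit double loop from the question.
expected = torch.zeros(height, width, 3)
for i in range(height):
    for j in range(width):
        expected[i, j, :] = color_map[gray_image[i, j]]

# Integer-tensor indexing: result shape is gray_image.shape + color_map.shape[1:].
indexed = color_map[gray_image]

# Same lookup with index_select on the flattened indices.
selected = color_map.index_select(0, gray_image.reshape(-1)).reshape(height, width, 3)

# Same lookup with torch.gather: the index must first be expanded to one entry
# per output element, which is why gather is more awkward for this task.
gather_index = gray_image.reshape(-1, 1).expand(-1, 3)  # (height * width, 3)
gathered = torch.gather(color_map, 0, gather_index).reshape(height, width, 3)

print(indexed.shape)
print(torch.equal(indexed, expected))
```

All three variants should produce the same tensor as the loop. Plain indexing is the shortest because it needs no flattening or index expansion.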
