I am trying to implement the conversion of a grayscale image to an RGB image through a colormap myself in PyTorch. I have:
- A tensor `gray_image`, which is a LongTensor of size 512 x 512 with integer values between 0 and 255
- The `color_map`, which is a FloatTensor of size 256 x 3, where every row holds the RGB value for a specific gray value
- An empty output tensor, which is a FloatTensor of size 512 x 512 x 3
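To experiment with this, tensors with the shapes and dtypes listed above can be created from random data. This is a hypothetical stand-in, not the actual image or colormap:

```python
import torch

# Hypothetical random data matching the shapes and dtypes described above
torch.manual_seed(0)
gray_image = torch.randint(0, 256, (512, 512))  # int64 (LongTensor), values 0..255
color_map = torch.rand(256, 3)                  # float32, one RGB row per gray level
output = torch.zeros((512, 512, 3))             # float32, to be filled with RGB values
```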
I fill the output tensor like this:

```python
import torch

output = torch.zeros((512, 512, 3))  # Initialize empty
gray_image = (gray_image * 255).long()  # Scale to integer values 0..255 as a LongTensor
for i in range(gray_image.shape[0]):
    for j in range(gray_image.shape[1]):
        output[i, j, :] = color_map[gray_image[i, j]]  # Look up the RGB row for this gray value
```
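The double loop above is a per-pixel table lookup, which PyTorch's advanced indexing can perform in a single vectorized step (this is plain indexing, not `torch.gather()`). A minimal sketch, assuming random stand-in data for `gray_image` and `color_map`:

```python
import torch

torch.manual_seed(0)
gray_image = torch.randint(0, 256, (512, 512))  # stand-in LongTensor, values 0..255
color_map = torch.rand(256, 3)                  # stand-in colormap, one RGB row per gray level

# Indexing the first dimension of color_map with a (512, 512) LongTensor
# looks up one row per pixel, producing a (512, 512, 3) tensor in one step.
output = color_map[gray_image]

# Cross-check against the explicit loop on a small patch (the full loop is slow)
for i in range(4):
    for j in range(4):
        assert torch.equal(output[i, j, :], color_map[gray_image[i, j]])
```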
This is quite slow, and it should be possible in one or two lines using `torch.gather()`. I tried to do this myself but did not succeed, and I found the documentation of `torch.gather()` to be quite complex. Can someone help me with this?
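For reference, here is a sketch of how `torch.gather()` could express the same lookup, again with random stand-in data. Along `dim=0`, `torch.gather(input, 0, index)` computes `out[k][c] = input[index[k][c]][c]` and requires `index` to have the same number of dimensions as `input`, so each pixel's gray value is repeated across the 3 color channels:

```python
import torch

torch.manual_seed(0)
gray_image = torch.randint(0, 256, (512, 512))  # stand-in LongTensor, values 0..255
color_map = torch.rand(256, 3)                  # stand-in colormap, shape (256, 3)

# Flatten the image to one row per pixel and repeat its gray value for R, G, B,
# so index[k][c] == gray value of pixel k for every channel c.
index = gray_image.reshape(-1, 1).expand(-1, 3)  # shape (512*512, 3)

# out[k][c] = color_map[index[k][c]][c] = color_map[gray value of pixel k][c]
flat = torch.gather(color_map, 0, index)         # shape (512*512, 3)

# Restore the image layout with a trailing channel dimension
output = flat.view(*gray_image.shape, 3)         # shape (512, 512, 3)
```

Plain indexing, `color_map[gray_image]`, yields the same tensor and tends to read more simply; the `gather` form mainly helps in understanding how the `index` argument maps to the output.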