Perform a 2D 'source pixel' lookup on a batch of 17 RGB images


I have a set of 3-channel (RGB) 512x512 images with a batch size of 17, stored in a tensor of shape (17, 3, 512, 512).

I also have a set of 2D lookup coordinates in a tensor of shape (17, 512, 512, 2), where the final dimension of size 2 holds the XY coordinate of the source pixel to look up in the original image. This lets me rearrange the pixels of the 17 source images into 17 new resulting images.

Thanks to some help from this forum, I had this ‘2D pixel lookup’ working when I had a single grayscale 2D source image, so my image tensor was just (512, 512) and the ‘2D lookup coordinates’ tensor was (512, 512, 2). The line that made that work was:

    new_images = orig_images[..., lookup_coords[:, :, 0], lookup_coords[:, :, 1]]

But now that I have 3 channels in my image and a batch size of 17, this line produces a resulting shape of (17, 3, 17, 512, 2) instead of the (17, 3, 512, 512) I wanted, which would match the input textures.
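To illustrate where the extra dimensions come from, here is a minimal sketch using smaller, made-up sizes (4 images of 8x8 rather than 17 of 512x512). The variable names `index_a` and `index_b` are only for illustration. On a 4D coordinates tensor, `[:, :, 0]` slices the third axis rather than the last one, so each index tensor keeps the batch axis and the trailing axis of size 2:

```python
import torch

# Hypothetical small sizes standing in for (17, 3, 512, 512)
B, C, H, W = 4, 3, 8, 8
orig_images = torch.rand(B, C, H, W)
lookup_coords = torch.randint(0, H, (B, H, W, 2))

# With a batch axis in front, [:, :, 0] picks index 0 of the third axis,
# so each index tensor has shape (B, H, 2) instead of (H, W).
index_a = lookup_coords[:, :, 0]
index_b = lookup_coords[:, :, 1]

# The two index tensors replace the last two image axes with their own shape.
wrong = orig_images[..., index_a, index_b]
print(wrong.shape)  # torch.Size([4, 3, 4, 8, 2])
```

The printed shape follows the same (batch, channels, batch, height, 2) pattern as the (17, 3, 17, 512, 2) result above.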

I’ve messed with that line a lot, and tried squeezing and unsqueezing various items, but I just can’t get the right output shape. Can anyone help? Thanks if you can.

P.S. In case the above wasn’t clear, each of the 17 input images has its own unique 2D lookup coordinate map.

Could you provide a (slow) reference implementation we could use to verify a potentially better approach?

Yes, good plan. In fact, the ‘slow’ reference code I wrote isn’t even slow; it just loops over each image in the batch. With only 17 images in this case, I doubt there will be any noticeable speed loss, so I think this issue is fixed now, thanks.

For anyone wanting to see what I did, I added this to deal with each image in the batch separately:

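    # Pre-allocate the output with the same shape as the input, (17, 3, 512, 512)
    output_images = torch.zeros_like(batch_of_images)

    batch_size = batch_of_images.size(0)
    for batch_index in range(batch_size):
        # Pick out one image, (3, 512, 512), and its own coordinate map, (512, 512, 2)
        orig_image = batch_of_images[batch_index]
        lookup_coords = batch_of_coords[batch_index]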

…and that brought my (17, 3, 512, 512) input images tensor down to just (3, 512, 512), which actually got this line of code working to perform the 2D pixel lookup:

        output_images[batch_index] = orig_image[..., lookup_coords[..., 0], lookup_coords[..., 1]]

I could maybe optimize this code by getting rid of the zeros_like() bit and just stacking the result images into one tensor, but I’ll leave that for another day when I know how to work with tensors better.
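For anyone who wants to try that later, here is a rough sketch of two variants, checked against each other on small made-up sizes. The function names `lookup_loop` and `lookup_batched` are hypothetical, not from any library. The first keeps the per-image loop but collects the results with `torch.stack` instead of `zeros_like()`. The second drops the loop by adding broadcastable batch and channel index tensors. Both assume, as in the line above, that `coords[..., 0]` indexes the second-to-last image axis and `coords[..., 1]` the last, and that the coordinates are integer (long) tensors within range.

```python
import torch

def lookup_loop(batch_of_images, batch_of_coords):
    """Per-image lookup, collecting results with torch.stack."""
    results = [
        image[..., coords[..., 0], coords[..., 1]]  # (C, H, W) per image
        for image, coords in zip(batch_of_images, batch_of_coords)
    ]
    return torch.stack(results)  # (B, C, H, W)

def lookup_batched(batch_of_images, batch_of_coords):
    """Loop-free lookup using broadcast advanced indexing."""
    B, C = batch_of_images.shape[:2]
    device = batch_of_images.device
    batch_idx = torch.arange(B, device=device).view(B, 1, 1, 1)
    chan_idx = torch.arange(C, device=device).view(1, C, 1, 1)
    first = batch_of_coords[..., 0].unsqueeze(1)   # (B, 1, H, W)
    second = batch_of_coords[..., 1].unsqueeze(1)  # (B, 1, H, W)
    # All four index tensors broadcast to (B, C, H, W).
    return batch_of_images[batch_idx, chan_idx, first, second]

# Small hypothetical sizes standing in for (17, 3, 512, 512)
torch.manual_seed(0)
B, C, H, W = 4, 3, 8, 8
images = torch.rand(B, C, H, W)
coords = torch.stack(
    [torch.randint(0, H, (B, H, W)), torch.randint(0, W, (B, H, W))], dim=-1
)

out_loop = lookup_loop(images, coords)
out_batched = lookup_batched(images, coords)
assert torch.equal(out_loop, out_batched)
```

I haven’t measured whether the loop-free version is actually faster for a batch of 17, so treat it as something to benchmark rather than a guaranteed win.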