Index_select where the indices are 2D rather than 1D?

Hi!

I’m trying to make my own nn.Module where the input and output are both 2D grayscale images. Each output pixel (ignoring batches for now) copies one pixel of the input image, chosen by a lookup tensor that holds a [row, column] coordinate for every output pixel. For example, this snippet should rotate the corner pixels (the ones with values 40 to 43) of the 4x4 input image clockwise:

        import torch

        input_image = torch.Tensor([
            [40,  1,  2, 41],
            [ 3,  4,  5,  6],
            [ 7,  8,  9,  1],
            [42,  2,  3, 43]
        ])

        coords_to_look_up = torch.Tensor([
            [[3, 0], [0, 1], [0, 2], [0, 0]],
            [[1, 0], [1, 1], [1, 2], [1, 3]],
            [[2, 0], [2, 1], [2, 2], [2, 3]],
            [[3, 3], [3, 1], [3, 2], [0, 3]],
        ]).long()

        result_i_want = torch.Tensor([
            [42,  1,  2, 40],
            [ 3,  4,  5,  6],
            [ 7,  8,  9,  1],
            [43,  2,  3, 41]
        ])

What operation would I perform on input_image and coords_to_look_up to produce result_i_want? I found index_select, but it seems to handle only 1D indices.
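For reference, `index_select` takes a 1D index tensor and selects whole slices along a single dimension. It can pick rows or columns, but it cannot pair a row index with a column index for each output pixel. A minimal illustration (the name `image` is just for this example):

```python
import torch

image = torch.tensor([
    [40,  1,  2, 41],
    [ 3,  4,  5,  6],
    [ 7,  8,  9,  1],
    [42,  2,  3, 43],
])

# index_select(dim, index) with a 1D index picks whole slices along ONE dimension:
# here it returns rows 3 and 0 of the image, not individual pixels.
rows = image.index_select(0, torch.tensor([3, 0]))
# rows is [[42, 2, 3, 43], [40, 1, 2, 41]]
```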

Thanks for any help. 😊

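Direct (advanced) indexing with the row and column index tensors should work. Each output element is input_image[row, col] for the matching entries of the two index tensors, so the result has the same 4x4 shape as the lookup table:

        input_image[coords_to_look_up[:, :, 0], coords_to_look_up[:, :, 1]]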
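Since the goal is an nn.Module that will eventually handle batches, the same indexing can be wrapped in a small module. This is a sketch, assuming one lookup table shared by every image in the batch; the class name `PixelLookup` and the buffer names `rows` and `cols` are hypothetical. Indexing with `...` in front means the lookup applies to the last two dimensions, so a batch dimension (or a channel dimension) is carried along unchanged.

```python
import torch
from torch import nn


class PixelLookup(nn.Module):
    """Hypothetical module: out[..., i, j] = x[..., coords[i, j, 0], coords[i, j, 1]].

    Assumes one (H_out, W_out, 2) table of [row, col] pairs shared by every
    image in the batch.
    """

    def __init__(self, coords: torch.Tensor):
        super().__init__()
        # Buffers follow .to(device) and are saved in state_dict, but are not trained.
        self.register_buffer("rows", coords[..., 0].long().contiguous())
        self.register_buffer("cols", coords[..., 1].long().contiguous())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Advanced indexing on the last two dims, so (H, W), (B, H, W) and
        # (B, C, H, W) inputs all work; the output takes the table's H_out x W_out shape.
        return x[..., self.rows, self.cols]


input_image = torch.tensor([
    [40,  1,  2, 41],
    [ 3,  4,  5,  6],
    [ 7,  8,  9,  1],
    [42,  2,  3, 43],
], dtype=torch.float32)

coords_to_look_up = torch.tensor([
    [[3, 0], [0, 1], [0, 2], [0, 0]],
    [[1, 0], [1, 1], [1, 2], [1, 3]],
    [[2, 0], [2, 1], [2, 2], [2, 3]],
    [[3, 3], [3, 1], [3, 2], [0, 3]],
])

lookup = PixelLookup(coords_to_look_up)
single = lookup(input_image)                               # shape (4, 4)
batch = lookup(torch.stack([input_image, -input_image]))   # shape (2, 4, 4)
```

Indexing is differentiable with respect to the input, so gradients flow back to whichever pixels were copied. In this particular table every input pixel is used exactly once; if a pixel were looked up twice, its gradient contributions would be summed.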

Thank you! Your solution looks better than the monstrosity I just came up with, which I’ve put here as some sort of Friday horror show for you:

        x_width = input_image.size(1)   # number of columns
        y_width = input_image.size(0)   # number of rows
        input_array_flat = input_image.flatten(0)
        # Interleaved [row, col] pairs: [r0, c0, r1, c1, ...]
        coord_sources_flat = coords_to_look_up.flatten(0)
        # Multipliers [x_width, 1, x_width, 1, ...] turn each pair into (row * x_width, col)
        mult_y_by_x_size = torch.tensor([x_width, 1]).repeat(x_width * y_width)
        coords_y_scaled = coord_sources_flat * mult_y_by_x_size
        # Every second element from offset 1 is a column; from offset 0, a scaled row
        just_x_coords        = torch.as_strided(coords_y_scaled, (x_width * y_width,), (2,), 1)
        just_y_scaled_coords = torch.as_strided(coords_y_scaled, (x_width * y_width,), (2,))
        # Row-major linear index: row * x_width + col
        indices_flat = just_x_coords + just_y_scaled_coords
        result_i_want_flat = input_array_flat.index_select(0, indices_flat)

        result_i_want = result_i_want_flat.reshape((y_width, x_width))
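The flattened approach can be written more compactly by computing the row-major index `row * width + col` with ordinary tensor arithmetic instead of the interleave-and-stride steps. A sketch, assuming a single contiguous (H, W) image; the names `width`, `flat_idx`, `via_take` and `via_index_select` are just for illustration. `torch.take` treats its input as 1D and returns a tensor shaped like the index.

```python
import torch

input_image = torch.tensor([
    [40,  1,  2, 41],
    [ 3,  4,  5,  6],
    [ 7,  8,  9,  1],
    [42,  2,  3, 43],
], dtype=torch.float32)

coords_to_look_up = torch.tensor([
    [[3, 0], [0, 1], [0, 2], [0, 0]],
    [[1, 0], [1, 1], [1, 2], [1, 3]],
    [[2, 0], [2, 1], [2, 2], [2, 3]],
    [[3, 3], [3, 1], [3, 2], [0, 3]],
])

width = input_image.size(1)

# Row-major linear index of each [row, col] pair: row * width + col.
# It already has the lookup table's (4, 4) shape, so no as_strided tricks are needed.
flat_idx = coords_to_look_up[..., 0] * width + coords_to_look_up[..., 1]

# torch.take treats the input as 1D and returns a tensor shaped like the index.
via_take = torch.take(input_image, flat_idx)

# The same lookup through index_select, with an explicit flatten/reshape.
via_index_select = (
    input_image.flatten()
    .index_select(0, flat_idx.flatten())
    .reshape(flat_idx.shape)
)
```

One design difference from the code above: the output shape here follows the lookup table rather than `(y_width, x_width)`, so the table does not have to be the same size as the input image.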