Switching Indices of Tensor

I have a tensor, A, of shape B x N x 2. For instance:

B = 4
N = 5
A = torch.randn(B, N, 2) #(B x N x 2) Tensor

I then have a tensor of indices that should, for some batch elements, swap the two entries along the last dimension of the tensor:

I = torch.tensor([[0,1], [1,0], [1,0], [0,1]]) #(B x 2) Tensor

A row of [0,1] ought to do nothing to the corresponding batch element. A row of [1,0] ought to switch its two columns.

How do I index A with I to perform this operation in PyTorch?

Example:

Suppose I have

print(A[0, :, :])

tensor([
        [1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10]])

and

print(A[1, :, :])

tensor([
        [1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10]])

Then foo(A, I), the indexing operation, should leave the first batch element unchanged ([0,1]) and swap the columns of the second batch element ([1,0]):

print(foo(A, I)[0, :, :])

tensor([
        [1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10]])

and

print(foo(A, I)[1, :, :])

tensor([
        [2, 1],
        [4, 3],
        [6, 5],
        [8, 7],
        [10, 9]])
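
To pin down the intended behavior, here is a minimal loop-based sketch of what foo is meant to compute. The name foo_reference and the integer-valued A are only illustrative assumptions (A is built to match the example values above); it is a slow reference for the semantics, not the vectorized operation being asked for.

```python
import torch

def foo_reference(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: reorder the last dimension of A[b] by I[b]."""
    # A[b] has shape (N, 2); indexing its columns with I[b] either keeps
    # them ([0, 1]) or swaps them ([1, 0]).
    return torch.stack([A[b][:, I[b]] for b in range(A.shape[0])])

B, N = 4, 5
# Every batch element holds the values 1..10, as in the example above.
A = torch.arange(1, 2 * N + 1).reshape(N, 2).repeat(B, 1, 1)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])

result = foo_reference(A, I)
print(result[0])  # unchanged batch element
print(result[1])  # batch element with swapped columns
```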

Hi Jack!

You can use torch.gather() (after adjusting the shape of I):

>>> import torch
>>> torch.__version__
'1.10.2'
>>> B = 4
>>> N = 5
>>> A = (torch.arange (N * 2) + 1).reshape (1, N, 2).expand (B, N, 2)
>>> A
tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6],
         [ 7,  8],
         [ 9, 10]],

        [[ 1,  2],
         [ 3,  4],
         [ 5,  6],
         [ 7,  8],
         [ 9, 10]],

        [[ 1,  2],
         [ 3,  4],
         [ 5,  6],
         [ 7,  8],
         [ 9, 10]],

        [[ 1,  2],
         [ 3,  4],
         [ 5,  6],
         [ 7,  8],
         [ 9, 10]]])
>>> I = torch.tensor ([[0, 1], [1, 0], [1, 0], [0, 1]])
>>> torch.gather (A, 2, I.unsqueeze (1).expand (B, N, 2))
tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6],
         [ 7,  8],
         [ 9, 10]],

        [[ 2,  1],
         [ 4,  3],
         [ 6,  5],
         [ 8,  7],
         [10,  9]],

        [[ 2,  1],
         [ 4,  3],
         [ 6,  5],
         [ 8,  7],
         [10,  9]],

        [[ 1,  2],
         [ 3,  4],
         [ 5,  6],
         [ 7,  8],
         [ 9, 10]]])
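
The gather() call above could be wrapped in a function along the following lines. This is a sketch rather than a definitive implementation: foo is the placeholder name from the question, and expand (-1, N, -1) is used so that the sizes do not have to be written out by hand. The second function, foo_indexing, is an alternative based on broadcasting index tensors that I believe gives the same result; it is included only as a cross-check.

```python
import torch

def foo(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    """Sketch: reorder the last dimension of A (B, N, 2) per batch using I (B, 2)."""
    # Insert a length-1 axis into I and broadcast it over N: (B, 2) -> (B, N, 2).
    # -1 tells expand() to keep the existing size of that dimension.
    index = I.unsqueeze(1).expand(-1, A.shape[1], -1)
    # out[b, n, k] = A[b, n, index[b, n, k]] = A[b, n, I[b, k]]
    return torch.gather(A, 2, index)

def foo_indexing(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    """Alternative sketch using advanced indexing with broadcast index tensors."""
    B, N, _ = A.shape
    b = torch.arange(B)[:, None, None]  # (B, 1, 1)
    n = torch.arange(N)[None, :, None]  # (1, N, 1)
    return A[b, n, I[:, None, :]]       # indices broadcast to (B, N, 2)

B, N = 4, 5
A = (torch.arange(N * 2) + 1).reshape(1, N, 2).expand(B, N, 2)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])

out = foo(A, I)
print(out[1])  # second batch element, columns swapped
```

Note that I must be an integer (int64) tensor for gather(), which is what torch.tensor() produces from Python ints.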

Best.

K. Frank