I have a tensor, A, of shape B x N x 2. For instance:
import torch

B = 4
N = 5
A = torch.randn(B, N, 2)  # (B x N x 2) tensor
I then have a tensor of indices that should, for some batch elements, 'swap' the two entries along the last dimension of the tensor:
I = torch.tensor([[0,1], [1,0], [1,0], [0,1]]) #(B x 2) Tensor
A [0,1] row should leave that batch element unchanged. A [1,0] row should swap its two columns.
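To pin down the intended semantics, here is a minimal (slow) reference sketch that loops over the batch and reorders the last-dimension entries of each A[b] with the permutation I[b]. The function name `foo_loop` is hypothetical; it only restates the behaviour described above and is not meant as the vectorized answer.

```python
import torch

def foo_loop(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference version: for each batch element b, select the
    # columns of A[b] (shape N x 2) in the order given by I[b].
    # [0, 1] keeps the columns as they are; [1, 0] swaps them.
    return torch.stack([A[b][:, I[b]] for b in range(A.shape[0])])

B, N = 4, 5
A = torch.randn(B, N, 2)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])
out = foo_loop(A, I)  # shape (B, N, 2)
```

A Python loop like this is fine for checking results, but it scales poorly with B, which is why a single indexing operation is preferable.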
How do I index A with I to perform this operation in PyTorch?
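One possible approach, sketched here under the assumption that I is an int64 tensor of shape (B, 2) as in the question, is `torch.gather` along the last dimension. `gather` needs an index tensor with the same number of dimensions as A, so I is given a middle axis and expanded to (B, N, 2); the function name `foo` simply mirrors the question.

```python
import torch

def foo(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    # I has shape (B, 2). Add a middle axis and expand (without copying)
    # to (B, N, 2) so there is one index per element of A.
    index = I.unsqueeze(1).expand(-1, A.shape[1], -1)
    # gather along dim 2: out[b, n, k] = A[b, n, index[b, n, k]]
    #                                  = A[b, n, I[b, k]]
    return torch.gather(A, 2, index)

B, N = 4, 5
A = torch.randn(B, N, 2)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])
out = foo(A, I)  # rows with [1, 0] have their two columns swapped
```

Because each row of I is a permutation of [0, 1], this also generalizes to a last dimension of any size K, as long as I has shape (B, K).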
Example:
Suppose I have
print(A[0, :, :])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10]])
and
print(A[1, :, :])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10]])
Then foo(A, I), the indexing operation, should leave the first batch element unchanged (its index row is [0,1]) and swap the columns of the second batch element (its index row is [1,0]):
print(foo(A, I)[0, :, :])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10]])
and
print(foo(A, I)[1, :, :])
tensor([[2, 1],
        [4, 3],
        [6, 5],
        [8, 7],
        [10, 9]])
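The example above can be reproduced with plain advanced indexing as well. This sketch assumes every batch element holds the same 5 x 2 matrix [[1, 2], ..., [9, 10]] (built with `torch.arange`, an illustrative setup rather than the original data). It broadcasts three index tensors to shape (B, N, 2). The `torch.where` variant at the end is an alternative that only works because I never holds anything other than [0, 1] or [1, 0].

```python
import torch

# Illustrative setup: each batch element is [[1, 2], [3, 4], ..., [9, 10]].
B, N = 4, 5
A = torch.arange(1, 2 * N + 1).reshape(1, N, 2).repeat(B, 1, 1)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])

# Advanced indexing: the three index tensors broadcast to (B, N, 2),
# so out[b, n, k] = A[b, n, I[b, k]].
batch_idx = torch.arange(B)[:, None, None]  # (B, 1, 1)
row_idx = torch.arange(N)[None, :, None]    # (1, N, 1)
out = A[batch_idx, row_idx, I[:, None, :]]  # I[:, None, :] is (B, 1, 2)

# Alternative relying on I only ever being [0, 1] or [1, 0]:
# flip the last dimension wherever the row of I starts with 1.
swap = (I[:, 0] == 1)[:, None, None]        # (B, 1, 1) boolean mask
out_where = torch.where(swap, A.flip(-1), A)

print(out[0])  # unchanged: [[1, 2], [3, 4], ..., [9, 10]]
print(out[1])  # swapped:   [[2, 1], [4, 3], ..., [10, 9]]
```

The gather and advanced-indexing versions handle arbitrary permutations in I. The `torch.where` version is shorter but encodes the two-column swap assumption directly.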