I have a tensor, `A`, of shape `B` x `N` x 2. For instance:

```
import torch

B = 4
N = 5
A = torch.randn(B, N, 2)  # (B x N x 2) tensor
```

I then have a tensor of indices, one pair per batch element, that should sometimes swap the two entries along the last dimension of `A`:

`I = torch.tensor([[0,1], [1,0], [1,0], [0,1]]) #(B x 2) Tensor`

The `[0,1]` indices ought to do nothing. The `[1,0]` indices ought to swap the last two columns.
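Because `[0,1]` means "keep" and `[1,0]` means "swap", the operation amounts to reversing the last dimension for selected batch elements only. A minimal sketch, assuming every row of `I` is exactly `[0,1]` or `[1,0]` (the helper name `swap_where` is hypothetical):

```python
import torch

def swap_where(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    # Assumption: each row of I is either [0, 1] (keep) or [1, 0] (swap).
    swap = I[:, 0] == 1  # (B,) bool, True where the row is [1, 0]
    # Broadcast the per-batch flag over the N and 2 dimensions and pick
    # either the column-reversed copy or the original values.
    return torch.where(swap[:, None, None], A.flip(-1), A)

B, N = 4, 5
A = torch.randn(B, N, 2)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])
out = swap_where(A, I)
```

This only handles the two-permutation case; it does not use `I` as a general index.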

How do I index `A` with `I` to perform this operation in PyTorch?
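One way to index `A` with `I` directly is `torch.gather` along the last dimension, after broadcasting `I` from `(B, 2)` to `(B, N, 2)` so that every row of a batch element reuses that element's index pair. This is a sketch, not a definitive answer; `swap_gather` is a hypothetical name:

```python
import torch

def swap_gather(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    B, N, _ = A.shape
    # Insert a length-1 middle axis and expand it (no copy) so all N rows
    # of batch element b use I[b]: (B, 2) -> (B, N, 2).
    index = I[:, None, :].expand(B, N, I.shape[1])
    # gather along dim 2: out[b, n, k] = A[b, n, index[b, n, k]] = A[b, n, I[b, k]]
    return torch.gather(A, 2, index)

B, N = 4, 5
A = torch.randn(B, N, 2)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])  # int64, as gather requires
out = swap_gather(A, I)
```

Since each output column simply picks the column named by `I`, this also works for index rows other than `[0,1]` and `[1,0]`.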

**Example:**

Suppose I have

```
print(A[0, :, :])
tensor([
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10]])
```

and

```
print(A[1, :, :])
tensor([
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10]])
```

Then `foo(A, I)`, the indexing operation, should leave the first batch element unchanged (`[0,1]`) and swap the columns of the second batch element (`[1,0]`):

```
print(foo(A, I)[0, :, :])
tensor([
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10]])
```

and

```
print(foo(A, I)[1, :, :])
tensor([
[2, 1],
[4, 3],
[6, 5],
[8, 7],
[10, 9]])
```
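The example can also be checked with plain advanced indexing instead of `gather`: three index tensors that broadcast to `(B, N, 2)` select `A[b, n, I[b, k]]` for every output position. A sketch that rebuilds the example data, assuming every batch element holds `[[1, 2], ..., [9, 10]]` to match the printout above:

```python
import torch

B, N = 4, 5
# Every batch element is [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]].
A = torch.arange(1, 2 * N + 1).reshape(1, N, 2).repeat(B, 1, 1)
I = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]])

# The index tensors broadcast together to shape (B, N, 2),
# so out[b, n, k] = A[b, n, I[b, k]].
batch = torch.arange(B)[:, None, None]  # (B, 1, 1)
rows = torch.arange(N)[None, :, None]   # (1, N, 1)
out = A[batch, rows, I[:, None, :]]     # I[:, None, :] has shape (B, 1, 2)

print(out[0])  # unchanged
print(out[1])  # columns swapped
```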