Hi, I’m facing an indexing problem. How can I optimize it?

I want to select a different set of indices for each item in the batch.

For example:

```
x = [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
idx = [[0, 2], [2, 1]]
# ... indexing step: for each batch b, select rows idx[b] from x[b] ...
result = [ [[1, 2], [5, 6]], [[11, 12], [9, 10]] ]
```

I implemented the indexing with a for loop, but it is very slow.

So the real problem is how to optimize this for loop.

```
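import torch

# x is a 4-dim tensor; the sizes below are example values.
B, C, W, H = 2, 12, 4, 4
x = torch.empty((B, C, W, H))
idx = torch.tensor([[3, 5, 10], [2, 4, 1]])  # shape (B, K): indices into dim 1
result = []
for row, i in enumerate(idx):
    # x[row, i, :, :] has shape (K, W, H)
    result += [x[row, i, :, :]]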
```
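For comparison, here is a sketch of how the same selection might be written without the Python loop. It relies on advanced indexing, with a batch index of shape `(B, 1)` broadcast against `idx` of shape `(B, K)`, and also shows an equivalent `torch.gather` form. The sizes `B, C, W, H = 2, 12, 4, 4` and the names `batch`, `looped`, `vectorized`, `gather_idx`, and `gathered` are assumptions made for illustration only.

```python
import torch

# Assumed example sizes; C must be larger than the biggest index in idx.
B, C, W, H = 2, 12, 4, 4
x = torch.randn(B, C, W, H)
idx = torch.tensor([[3, 5, 10], [2, 4, 1]])  # shape (B, K)

# Loop version from above, stacked into one tensor for comparison.
looped = torch.stack([x[row, i, :, :] for row, i in enumerate(idx)])

# Batch indices of shape (B, 1) broadcast against idx of shape (B, K),
# so element (b, k) of the result is x[b, idx[b, k]].
batch = torch.arange(B).unsqueeze(1)
vectorized = x[batch, idx]  # shape (B, K, W, H)

# Equivalent form using torch.gather along dim 1.
gather_idx = idx[:, :, None, None].expand(-1, -1, W, H)
gathered = torch.gather(x, 1, gather_idx)  # shape (B, K, W, H)
```

Both forms return a single tensor of shape `(B, K, W, H)` rather than a list, and they require every batch item to select the same number of indices `K`.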

Thank you!