How to optimize tensor indexing

Hi, I’m facing an indexing problem. How can I optimize it?
I want to select different indices for each batch element.

For example,

x = [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
idx = [[0, 2], [2, 1]]
...
doing indexing
...
result = [ [[1, 2], [5, 6]], [[11, 12], [9, 10]] ]
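In other words, each row of `idx` picks rows from the matching batch element of `x`. For instance, result[1, 0] = x[1, 2] = [11, 12]:

```latex
\text{result}[b, j, k] = x\big[b,\ \text{idx}[b, j],\ k\big]
```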

I implemented the indexing with a for loop, but it is very slow.
So the real problem is how to optimize this for loop.

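import torch

# x is a 4-dim tensor. Example sizes; C must be larger than the biggest index in idx.
B, C, W, H = 2, 16, 8, 8
x = torch.empty((B, C, W, H))
idx = torch.tensor([[3, 5, 10], [2, 4, 1]])  # one row of channel indices per batch element

result = []
for row, i in enumerate(idx):
    result.append(x[row, i, :, :])  # shape (3, W, H)
result = torch.stack(result)         # shape (B, 3, W, H)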

Thank you!
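A loop-free alternative, not mentioned in this thread, is PyTorch advanced indexing: pairing a column of batch indices with `idx` makes the two broadcast together, so `x[batch, idx]` picks `x[b, idx[b, j]]` for every `b` and `j` in a single call. This is a minimal sketch; the concrete `B, C, W, H` values are made up for illustration.

```python
import torch

B, C, W, H = 2, 16, 8, 8  # hypothetical sizes; C must exceed the largest index
x = torch.randn(B, C, W, H)
idx = torch.tensor([[3, 5, 10], [2, 4, 1]])  # shape (B, K)

# Reference: the per-batch loop from the question, stacked into one tensor
looped = torch.stack([x[row, i] for row, i in enumerate(idx)])

# Advanced indexing: batch has shape (B, 1) and broadcasts against idx (B, K),
# so result[b, j] = x[b, idx[b, j]] and the trailing W, H dims are kept
batch = torch.arange(B).unsqueeze(1)
result = x[batch, idx]  # shape (B, K, W, H)

print(result.shape)                 # torch.Size([2, 3, 8, 8])
print(torch.equal(result, looped))  # True
```

Because the batch index tensor has a trailing dim of size 1, it lines up with every column of `idx` without any explicit expand.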

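import torch

# torch.gather needs idx to have the same number of dims as x,
# so this works for a 2-dim x (such as the earlier version of the example)
x = torch.tensor(x)
idx = torch.tensor(idx)
result = torch.gather(x, 1, idx)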
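For reference, `torch.gather(input, dim, index)` returns a tensor shaped like `index`, and along `dim=1` it computes `out[i][j] = input[i][index[i][j]]`. A small sketch with a made-up 2-dim `x`:

```python
import torch

# Hypothetical 2-dim input: 2 batch rows, 3 values each
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
idx = torch.tensor([[0, 2], [2, 1]])

# out[i][j] = x[i][idx[i][j]]
result = torch.gather(x, 1, idx)
print(result.tolist())  # [[10, 30], [60, 50]]
```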

Yes, in that case I can solve it with gather, but in the case below (x is 4-dim) I can’t.
I think the old example didn't represent this problem correctly, so I changed it.

For the moment I use this solution: Batched index_select, but I’m also curious whether there is a better alternative.
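The linked Batched index_select thread is not reproduced here. One common way to batch `index_select`, sketched below as an assumption rather than that thread's exact code, is to merge the batch and channel dims and shift row `b` of `idx` by `b * C` so it points into the flattened tensor. The helper name `batched_index_select` is hypothetical.

```python
import torch

def batched_index_select(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """For x of shape (B, C, *rest) and idx of shape (B, K), return a tensor
    of shape (B, K, *rest) with out[b, j] = x[b, idx[b, j]]."""
    B, C = x.shape[:2]
    K = idx.shape[1]
    # Offset row b's indices by b * C so they address the flattened B*C dim
    offsets = torch.arange(B, device=idx.device).unsqueeze(1) * C
    flat_idx = (idx + offsets).reshape(-1)    # shape (B*K,)
    flat_x = x.reshape(B * C, *x.shape[2:])   # shape (B*C, *rest)
    out = flat_x.index_select(0, flat_idx)    # shape (B*K, *rest)
    return out.reshape(B, K, *x.shape[2:])

x = torch.randn(2, 16, 8, 8)
idx = torch.tensor([[3, 5, 10], [2, 4, 1]])
out = batched_index_select(x, idx)
print(out.shape)  # torch.Size([2, 3, 8, 8])
```

`reshape` (rather than `view`) is used so the sketch also works when `x` is not contiguous, at the cost of a possible copy.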


One of my friends solved this problem.

The solution is to expand the index matrix so it has the same shape as the output, and then use gather:

import torch

# for 3-dim input
X = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
I = torch.tensor([[0, 2], [2, 1]])
eI = I[..., None].expand(-1, -1, X.size(2))    ## expanding index: (B, K) -> (B, K, X.size(2))
# gather already returns shape (B, K, X.size(2)); a trailing .squeeze() would also drop B or K if either were 1
Y = torch.gather(X, dim=1, index=eI)

# for 4-dim input
X = torch.tensor([[[[0, 0], [1, 1]], [[2, 2], [3, 3]], [[4, 4], [5, 5]]],
                  [[[6, 6], [7, 7]], [[8, 8], [9, 9]], [[10, 10], [11, 11]]]])
I = torch.tensor([[0, 2], [2, 1]])
eI = I[..., None, None].expand(-1, -1, X.size(2), X.size(3))    ## expanding index: (B, K) -> (B, K, X.size(2), X.size(3))
Y = torch.gather(X, dim=1, index=eI)
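The same trick generalizes to any number of trailing dims: reshape `I` to `(B, K, 1, ..., 1)` and expand it over the trailing sizes of `X`. This is a minimal sketch; `batched_gather` is a hypothetical helper name, and the check compares it against the original per-batch loop on the 4-dim example.

```python
import torch

def batched_gather(X: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    """out[b, j, ...] = X[b, I[b, j], ...] for X of shape (B, C, *rest), I of shape (B, K)."""
    rest = X.shape[2:]
    # Append one singleton dim per trailing dim of X, then expand (no copy) to match
    eI = I.reshape(*I.shape, *([1] * len(rest))).expand(*I.shape, *rest)
    return torch.gather(X, dim=1, index=eI)

X = torch.tensor([[[[0, 0], [1, 1]], [[2, 2], [3, 3]], [[4, 4], [5, 5]]],
                  [[[6, 6], [7, 7]], [[8, 8], [9, 9]], [[10, 10], [11, 11]]]])
I = torch.tensor([[0, 2], [2, 1]])

Y = batched_gather(X, I)
looped = torch.stack([X[b, I[b]] for b in range(X.size(0))])
print(torch.equal(Y, looped))  # True
print(Y[1].tolist())           # [[[10, 10], [11, 11]], [[8, 8], [9, 9]]]
```

Since `expand` only creates a view, the enlarged index costs no extra memory; `gather` still writes a full output of shape `(B, K, *rest)`.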