Hi, I'm facing an indexing problem and would like to optimize it.
I want to select a different set of indices for each element of a batch.
For example,
x = [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
idx = [[0, 2], [2, 1]]
Indexing x with idx should give:
result = [ [[1, 2], [5, 6]], [[11, 12], [9, 10]] ]
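A minimal sketch of the toy example using PyTorch advanced indexing (an approach I'm suggesting, not something from the original post): pairing a column of batch indices, `torch.arange(B)[:, None]`, with `idx` selects `x[b, idx[b, k]]` for every `b` and `k` in a single operation. The variable name `batch` is hypothetical.

```python
import torch

# Toy data from the example: shape (B=2, N=3, D=2)
x = torch.tensor([[[1, 2], [3, 4], [5, 6]],
                  [[7, 8], [9, 10], [11, 12]]])
# idx[b] lists the rows to pick from x[b]: shape (B, K=2)
idx = torch.tensor([[0, 2], [2, 1]])

# Batch indices shaped (B, 1) broadcast against idx (B, K),
# so result[b, k] = x[b, idx[b, k]] and result has shape (B, K, D)
batch = torch.arange(x.size(0)).unsqueeze(1)
result = x[batch, idx]

print(result.tolist())
# [[[1, 2], [5, 6]], [[11, 12], [9, 10]]]
```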
I implemented this indexing with a for loop, but it is very slow.
So what I really need is a way to replace (vectorize) this loop.
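import torch

# x is a 4-D tensor of shape (B, C, W, H); example sizes (assumed) so idx is valid
B, C, W, H = 2, 16, 8, 8
x = torch.empty((B, C, W, H))
# idx has shape (B, K): K channel indices to select for each batch element
idx = torch.tensor([[3, 5, 10], [2, 4, 1]])
result = []
for row, i in enumerate(idx):
    result += [x[row, i, :, :]]  # each entry has shape (K, W, H)
result = torch.stack(result)  # shape (B, K, W, H)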
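A vectorized sketch of the 4-D case, assuming `idx` holds one row of K channel indices per batch element (shape (B, K)) and using arbitrary example sizes for B, C, W, H (the post leaves them unspecified). It replaces the loop with either advanced indexing or `torch.gather` and checks both against the loop's stacked output. The names `loop_result`, `fancy`, `gather_idx` and `gathered` are hypothetical.

```python
import torch

# Assumed example sizes; in the original post B, C, W, H are unspecified
B, C, W, H = 2, 16, 8, 8
x = torch.randn(B, C, W, H)
# idx: (B, K) channel indices to keep for each batch element
idx = torch.tensor([[3, 5, 10], [2, 4, 1]])

# Reference: the per-row loop from the question, stacked into one tensor
loop_result = torch.stack([x[row, i, :, :] for row, i in enumerate(idx)])

# Option 1: advanced indexing. batch (B, 1) broadcasts with idx (B, K),
# so fancy[b, k] = x[b, idx[b, k]] with shape (B, K, W, H)
batch = torch.arange(B, device=x.device).unsqueeze(1)
fancy = x[batch, idx]

# Option 2: torch.gather along dim 1. The index must have the same number
# of dims as x, so expand idx from (B, K) to (B, K, W, H)
gather_idx = idx[:, :, None, None].expand(-1, -1, W, H)
gathered = torch.gather(x, 1, gather_idx)

print(fancy.shape)                         # torch.Size([2, 3, 8, 8])
print(torch.equal(fancy, loop_result))     # True
print(torch.equal(gathered, loop_result))  # True
```

Both options do the selection in one indexing call instead of a Python-level loop over the batch. Advanced indexing is arguably the more readable of the two; `gather` needs the index expanded to the full output shape. Note that both return a copy, not a view of x, and that idx must be on the same device as x.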
Thank you!