How can I perform batch indexing?

I have one batched tensor with matrices:

import torch

batch_size = 64
num_elements = 500
channels = 2
a = torch.rand(batch_size, num_elements, channels)  # [64, 500, 2]

And one tensor of indices that describes which rows to pick from each matrix in the batch:

indices = 80
b = (torch.rand(batch_size, indices) * num_elements).long()  # [64, 80]

I want to index each matrix in the batch with the corresponding row of indices from the other tensor, i.e. result[i, k] = a[i, b[i, k]]. The resulting tensor should have shape [64, 80, 2].
How can I do that?


To index along two dimensions (batch plus one index dimension), try the following:

B, C1, C2 = 32, 1024, 3
input = torch.rand((B, C1, C2))  # [32, 1024, 3]
idx1 = torch.arange(B).view(-1, 1)  # [32, 1]
idx2 = torch.randint(low=0, high=C1, size=(B, 80))  # [32, 80]
output = input[idx1, idx2]  # [32, 80, 3]
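
Applied to the shapes from the question, the same pattern might look like the sketch below. The names batch_idx, num_indices, expanded and out_gather are made up for illustration and do not appear in the original posts. The torch.gather call is only there as a cross-check that both approaches select the same elements.

```python
import torch

# Shapes taken from the question; num_indices is a hypothetical name
# for the value the question calls `indices`.
batch_size, num_elements, channels, num_indices = 64, 500, 2, 80
a = torch.rand(batch_size, num_elements, channels)             # [64, 500, 2]
b = torch.randint(0, num_elements, (batch_size, num_indices))  # [64, 80]

# One batch index per row, shaped [64, 1] so it broadcasts against b ([64, 80]).
batch_idx = torch.arange(batch_size).unsqueeze(1)
out = a[batch_idx, b]                                          # [64, 80, 2]

# Cross-check: gather along dim 1 with the index expanded over the channel dim.
expanded = b.unsqueeze(-1).expand(-1, -1, channels)            # [64, 80, 2]
out_gather = torch.gather(a, 1, expanded)                      # [64, 80, 2]
```

Here out[i, k] holds a[i, b[i, k]], which is the per-batch lookup the question asks for.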

Similarly, to index along three dimensions:

B, C1, C2, C3 = 32, 1024, 16, 3
input = torch.rand((B, C1, C2, C3))  # [32, 1024, 16, 3]
idx1 = torch.arange(B).view(-1, 1, 1)  # [32, 1, 1]
idx2 = torch.arange(C1).view(1, -1, 1)  # [1, 1024, 1]
idx3 = torch.randint(low=0, high=C2, size=(B, C1, 8))  # [32, 1024, 8]
output = input[idx1, idx2, idx3]  # [32, 1024, 8, 3]
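
To see what the broadcasting in the three-dimensional case computes, here is a small sketch that compares it against an explicit loop. The sizes are reduced so the loop stays cheap, and the names x and ref are illustrative only. The index tensors of shape [B, 1, 1], [1, C1, 1] and [B, C1, 8] broadcast to [B, C1, 8], so each output entry is x[i, j, idx3[i, j, k]].

```python
import torch

# Smaller sizes than in the post, chosen only to keep the reference loop fast.
B, C1, C2, C3 = 4, 6, 16, 3
x = torch.rand(B, C1, C2, C3)                  # [4, 6, 16, 3]
idx3 = torch.randint(0, C2, (B, C1, 8))        # [4, 6, 8]

idx1 = torch.arange(B).view(-1, 1, 1)          # [4, 1, 1]
idx2 = torch.arange(C1).view(1, -1, 1)         # [1, 6, 1]
out = x[idx1, idx2, idx3]                      # [4, 6, 8, 3]

# Reference: pick the rows idx3[i, j] from x[i, j] one (i, j) pair at a time.
ref = torch.stack([
    torch.stack([x[i, j, idx3[i, j]] for j in range(C1)])
    for i in range(B)
])                                             # [4, 6, 8, 3]
```

The loop version is far slower on real sizes; it is only meant to make the semantics of the broadcasted index explicit.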