Indexing a tensor based on another tensor

Super basic question, but I am having problems with it.

Essentially, I have a tensor called probs with shape (batch_size, classes), let's say (4, 5), and another tensor called labs with shape (batch_size), let's say (4). Let's assume that the values of the tensors are as follows (probs is called a in the example):

a = tensor([[0.1, 0.2, 0.3, 0.2, 0.2],
            [0.4, 0.3, 0.2, 0.1, 0.0],
            [0.1, 0.4, 0.3, 0.1, 0.1],
            [0.2, 0.2, 0.2, 0.1, 0.3]])

and

labs = tensor([2, 1, 0, 3])

I want the result to be:

result = tensor([0.3, 0.3, 0.1, 0.1])

That is, I want element 2 of row 0, element 1 of row 1, element 0 of row 2, and element 3 of row 3 (all zero-indexed): for each row i, the element at column labs[i].

Does anyone know the best way to achieve this without using for loops? I thought that

probs[labs.unsqueeze(1)]

should do the trick, but apparently it doesn't. I also tried gather, but no luck so far.
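A minimal sketch of why that first attempt doesn't work, assuming PyTorch is imported as torch and reusing the example values (with the 0 written as 0.0). The variable name wrong is just for illustration. Indexing with a single integer tensor only indexes dimension 0, so each label selects a whole row rather than one element of it:

```python
import torch

# Example data from the question: 4 rows, 5 classes.
a = torch.tensor([[0.1, 0.2, 0.3, 0.2, 0.2],
                  [0.4, 0.3, 0.2, 0.1, 0.0],
                  [0.1, 0.4, 0.3, 0.1, 0.1],
                  [0.2, 0.2, 0.2, 0.1, 0.3]])
labs = torch.tensor([2, 1, 0, 3])

# One index tensor of shape (4, 1) indexes only the row dimension,
# so every label picks a full row of 5 values.
wrong = a[labs.unsqueeze(1)]
print(wrong.shape)  # torch.Size([4, 1, 5])
```

Getting one element per row needs an index for the column dimension as well as the row dimension, which is what the replies do in different ways.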

You can create an auxiliary vector of row offsets:
add_idx = torch.tensor([0,5,10,15])

then index the flattened tensor a like this:
result = a.flatten()[labs+add_idx]
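The offsets [0, 5, 10, 15] are row index times the number of classes, so they only fit a (4, 5) tensor. Here is a sketch that computes them from the shape instead, assuming a 2-D tensor and PyTorch imported as torch:

```python
import torch

a = torch.tensor([[0.1, 0.2, 0.3, 0.2, 0.2],
                  [0.4, 0.3, 0.2, 0.1, 0.0],
                  [0.1, 0.4, 0.3, 0.1, 0.1],
                  [0.2, 0.2, 0.2, 0.1, 0.3]])
labs = torch.tensor([2, 1, 0, 3])

num_rows, num_classes = a.shape
# Start of each row in the flattened (row-major) tensor.
# For a (4, 5) input this is tensor([0, 5, 10, 15]).
add_idx = torch.arange(num_rows, device=a.device) * num_classes
result = a.flatten()[labs + add_idx]
print(result)  # values: 0.3, 0.3, 0.1, 0.1
```

Creating the offsets on a.device keeps the indexing working when a is on the GPU.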

After playing a bit with gather, I got something that almost gives me what I need:

result = a.gather(1, labs.unsqueeze(1))

result has shape (4, 1) instead of (4), though. Obviously this can be fixed with the torch.squeeze function, but I'd like to know whether there is a more efficient way to do this.
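As a sketch of the gather route, assuming the example values and PyTorch imported as torch: squeeze returns a view that shares storage with its input, so the extra call is cheap and copies no data. Squeezing only dimension 1 (rather than calling squeeze() with no argument) keeps the batch dimension when batch_size happens to be 1:

```python
import torch

a = torch.tensor([[0.1, 0.2, 0.3, 0.2, 0.2],
                  [0.4, 0.3, 0.2, 0.1, 0.0],
                  [0.1, 0.4, 0.3, 0.1, 0.1],
                  [0.2, 0.2, 0.2, 0.1, 0.3]])
labs = torch.tensor([2, 1, 0, 3])  # int64, as gather requires

# Along dim 1: picked[i][0] = a[i][labs[i]]
picked = a.gather(1, labs.unsqueeze(1))  # shape (4, 1)
# Drop only the size-1 column dimension; this is a view, not a copy.
result = picked.squeeze(1)               # shape (4,)
print(result.shape)  # torch.Size([4])
```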

You can use a[range(4), labs].
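range(4) hard-codes the batch size. Below is a sketch of the same advanced indexing that takes the row count from the tensor, assuming the example values and PyTorch imported as torch. The row and label index tensors are paired elementwise, so the result has shape (batch_size,) directly, with no squeeze needed:

```python
import torch

a = torch.tensor([[0.1, 0.2, 0.3, 0.2, 0.2],
                  [0.4, 0.3, 0.2, 0.1, 0.0],
                  [0.1, 0.4, 0.3, 0.1, 0.1],
                  [0.2, 0.2, 0.2, 0.1, 0.3]])
labs = torch.tensor([2, 1, 0, 3])

# One row index per label, on the same device as a.
rows = torch.arange(a.size(0), device=a.device)
# result[i] = a[rows[i], labs[i]]
result = a[rows, labs]
print(result.shape)  # torch.Size([4])
```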
