Index selection usage

I’m trying to select values out of a tensor based on indices stored in another tensor.

import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])  # int64 by default, which index ops expect

I want to use labels as indices along dimension 1, picking one column per row, so that the result is tensor([[0.3], [1.0], [0.6]]).
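To pin down exactly which selection is meant, here is a slow but explicit loop over the rows. It is only a sketch of the intended semantics; it swaps the legacy torch.FloatTensor / .long() constructors for torch.tensor, and the names rows and desired are illustrative.

```python
import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])

# For row i, keep only the single column given by labels[i].
rows = [outputs[i, int(labels[i])] for i in range(outputs.size(0))]

# Stack the per-row scalars into shape (3,), then add a column dim -> (3, 1).
desired = torch.stack(rows).unsqueeze(1)
print(desired)
```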

I tried the code below:

import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])
print(torch.index_select(outputs, 1, labels))

but the output is:

tensor([[0.3000, 0.7000, 0.7000], 
        [0.0000, 1.0000, 1.0000], 
        [0.4000, 0.6000, 0.6000]])
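That output is what index_select is designed to produce: it picks whole slices along the given dimension, and the same list of column indices is applied to every row (column 0, then column 1 twice). A small check, assuming the modern torch.tensor constructors, shows it matches plain column indexing with outputs[:, labels]:

```python
import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])

# index_select along dim 1 returns columns 0, 1, 1 for every row -> shape (3, 3).
selected = torch.index_select(outputs, 1, labels)

# Equivalent: take all rows and index the columns with labels.
same = outputs[:, labels]

print(selected.shape)               # torch.Size([3, 3])
print(torch.equal(selected, same))  # True
```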

So how can I pick one value per row and get the desired result?

Thanks

torch.gather should do the job: with dim=1 it takes one column per row, using the matching row of the index tensor.

import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])
# gather expects index to have as many dims as outputs, so reshape labels (3,) -> (3, 1)
print(outputs.gather(1, labels.unsqueeze(1)))
# tensor([[0.3000],
#         [1.0000],
#         [0.6000]])
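For reference, the rule torch.gather follows for dim=1 is out[i][j] = outputs[i][index[i][j]]. The sketch below spells that rule out with a loop and compares it to gather, then shows an advanced-indexing alternative that pairs each row number with its label. The loop is only there to illustrate the semantics, and the variable names (manual, rows, picked) are hypothetical; note the indexing form returns shape (3,) instead of (3, 1).

```python
import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])

# gather along dim 1 with a (3, 1) index -> result of shape (3, 1).
index = labels.unsqueeze(1)
gathered = outputs.gather(1, index)

# Reference loop for the same rule: out[i][j] = outputs[i][index[i][j]].
manual = torch.empty(index.shape, dtype=outputs.dtype)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        manual[i, j] = outputs[i, int(index[i, j])]

# Advanced indexing: row i is paired with column labels[i]; result is 1-D.
rows = torch.arange(outputs.size(0))
picked = outputs[rows, labels]

print(torch.equal(gathered, manual))             # True
print(torch.equal(gathered.squeeze(1), picked))  # True
```

If a (3, 1) column is needed from the indexing version, picked.unsqueeze(1) gives the same shape as the gather result.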

Thanks as always 🙂