I’m trying to select values from a tensor based on indices stored in another tensor.
outputs = torch.FloatTensor([[0.3, 0.7], [0, 1], [0.4, 0.6]])
labels = torch.Tensor([0, 1, 1]).long()
I want to use labels as indices along dimension 1, picking one element from each row, so the result should be torch.tensor([[0.3], [1.0], [0.6]]).
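"Pick one element per row" is what torch.gather does along dim 1. The sketch below is one possible approach, not the only one. It assumes labels is an int64 tensor with one entry per row. gather needs an index with the same number of dimensions as the input, so labels is reshaped from (3,) to (3, 1):

```python
import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])  # int64 by default

# gather along dim 1: for row i, take outputs[i, labels[i]].
# The index must have the same number of dims as outputs, hence unsqueeze(1).
picked = torch.gather(outputs, 1, labels.unsqueeze(1))
print(picked)  # shape (3, 1)
```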
I tried the code below:
import torch

outputs = torch.FloatTensor([[0.3, 0.7], [0, 1], [0.4, 0.6]])
labels = torch.Tensor([0, 1, 1]).long()
print(torch.index_select(outputs, 1, labels))
but the output is:
tensor([[0.3000, 0.7000, 0.7000],
[0.0000, 1.0000, 1.0000],
[0.4000, 0.6000, 0.6000]])
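This output is consistent with how index_select works: it treats labels as a list of whole columns to take (0, 1, 1) and applies that same list to every row, so the result has shape (3, 3) and matches outputs[:, labels]. The sketch below checks that equivalence and, as an alternative to gather, pairs each label with its row index via advanced indexing so that one element is taken per row:

```python
import torch

outputs = torch.tensor([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
labels = torch.tensor([0, 1, 1])

# index_select along dim 1 takes columns 0, 1, 1 from EVERY row,
# which is the same as outputs[:, labels] and has shape (3, 3).
cols = torch.index_select(outputs, 1, labels)
assert torch.equal(cols, outputs[:, labels])

# Pairing row i with labels[i] picks a single element per row instead.
rows = torch.arange(outputs.size(0))
per_row = outputs[rows, labels]  # shape (3,)
print(per_row.unsqueeze(1))      # reshape to (3, 1)
```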
How can I get the desired result?
Thanks