Tensor batch indexing question

My situation is hard to explain, but I will do my best :)

I have a (128, 1) tensor: 128 rows, each holding the value 0 or 1.
I also have another tensor of shape (128, 2). Using the first tensor as an index, I want to pick one value from each row of the second tensor and get a new tensor of shape (128, 1).

How can I achieve this?


I think gather would work for you:

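import torch

x = torch.randn(128, 2)  # values to choose from
index = torch.empty(128, 1, dtype=torch.long).random_(2)  # random 0/1 column index per row
result = x.gather(1, index)  # shape (128, 1); result[i, 0] == x[i, index[i, 0]]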
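
To make the behaviour easier to check by eye, here is a minimal sketch on a tiny tensor. The 4-row shapes and the values are made up for illustration and stand in for your (128, 2) and (128, 1) tensors; the names picked, rows and picked_alt are just placeholders. It also shows what I believe is an equivalent way to write it with plain indexing.

```python
import torch

# Hypothetical stand-in for the (128, 2) tensor: 4 rows, 2 columns.
x = torch.tensor([[10., 11.],
                  [20., 21.],
                  [30., 31.],
                  [40., 41.]])

# Hypothetical stand-in for the (128, 1) tensor of 0/1 choices.
# gather expects an integer (torch.long) index; integer literals give that by default.
index = torch.tensor([[0], [1], [1], [0]])

# For dim=1: picked[i, 0] = x[i, index[i, 0]], so the result has shape (4, 1).
picked = x.gather(1, index)

# Same selection with advanced indexing: pair each row number with its chosen column.
rows = torch.arange(x.size(0))
picked_alt = x[rows, index.squeeze(1)].unsqueeze(1)

print(picked)
print(torch.equal(picked, picked_alt))
```

If your 0/1 tensor is stored as float or bool, you would probably need index.long() before passing it to gather.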