I feel like the solution should be super simple, but I can’t quite figure it out.
Given a float tensor x and the indices returned by torch.topk for the top k values (k = 2 in this case), how can I use those LongTensor indices to set every value of x that isn’t in the top k to zero, either in-place or as a new tensor?
import torch
x = torch.arange(0, 10, dtype=torch.float).view(2, 5)
0 1 2 3 4
5 6 7 8 9
[torch.FloatTensor of size 2x5]
topk, indices = torch.topk(x, 2)
The final product from all this should look like:
0 0 0 3 4
0 0 0 8 9
Of course, this example is structured so that the top-k values are always the last two elements of each row, but in my case their positions will vary randomly.
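For reference, here is a rough sketch of the kind of thing I'm after, assuming the top-k is taken along the last dimension (dim 1 for this 2D tensor) and that scatter_ is a reasonable tool for it. The names result, mask and x_inplace are just placeholders for this example, and there may well be a cleaner way:

```python
import torch

# Same example tensor as in the question, built explicitly as float.
x = torch.arange(0, 10, dtype=torch.float).view(2, 5)

# Values and indices of the 2 largest entries in each row.
topk, indices = torch.topk(x, 2)

# New object: start from zeros and copy the top-k values back
# into their original positions along dim 1.
result = torch.zeros_like(x).scatter_(1, indices, topk)

# In-place variant: mark the top-k positions with 1.0, convert to a
# boolean mask, then zero everything outside the mask.
mask = torch.zeros_like(x).scatter_(1, indices, 1.0).bool()
x_inplace = x.clone()  # clone only so that x stays intact for comparison
x_inplace.masked_fill_(~mask, 0)

print(result)
# tensor([[0., 0., 0., 3., 4.],
#         [0., 0., 0., 8., 9.]])
```

Since the positions come from indices rather than from slicing, this should also cover the case where the top-k values sit at arbitrary positions in each row.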
Thanks in advance!