How to use torch.topk() to set non-topk values of a tensor to zero?

I feel like the solution should be super simple, but I can't quite figure it out.
Given a float tensor and the indices of its top k values from torch.topk (here k=2), how can I use those long tensor indices to set every value of x that isn't in the top k to zero, either in-place or as a new tensor?

 x = torch.arange(0., 10.).reshape(2, 5)

 0  1  2  3  4
 5  6  7  8  9
 [torch.FloatTensor of size 2x5]
 
 topk, indices = torch.topk(x, 2)

The final product from all this should look like:

 0 0 0 3 4
 0 0 0 8 9

Of course this example is structured so that the top k values are always the last two elements of each row, but in my case their positions will vary randomly.

Thanks in advance!

Not sure if this is the best way, but I am now using:

 torch.zeros(2, 5).scatter_(1, indices, topk)

However, this seems to mess up any autograd related to these variables…

Your solution should work, but instead of using the in-place operation on Variables, it might be better to use the out-of-place version.

from torch.autograd import Variable

res = Variable(torch.zeros(2, 5))
res = res.scatter(1, indices, topk)
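
For reference, in current PyTorch Variable has been merged into Tensor, so the same out-of-place pattern works directly on plain tensors. A minimal sketch, assuming the same x, k=2, and dim=1 as above:

 import torch

 x = torch.arange(0., 10.).reshape(2, 5)
 topk, indices = torch.topk(x, k=2, dim=1)

 # Start from zeros and copy only the top-k values back into place.
 # The out-of-place scatter keeps autograd intact for the scattered values.
 res = torch.zeros_like(x).scatter(1, indices, topk)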

Suppose I want to change the sign of the top-k elements instead; is there a simple way to do that? The only way I'm aware of is a for-loop, but I believe that would be inefficient.
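
One vectorized option (a sketch, assuming the x, topk, and indices from the example above) is to scatter the negated top-k values back into x out-of-place, which avoids a Python loop entirely:

 # Replace the top-k entries of each row with their negated values.
 flipped = x.scatter(1, indices, -topk)

 # Alternatively, build a ±1 sign mask and multiply element-wise.
 sign = torch.ones_like(x).scatter(1, indices, -1.0)
 flipped_too = x * sign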