How to use torch.topk() to set non-topk values of a tensor to zero?

Not sure if this is the best way, but I am now using:

 torch.zeros(2, 5).scatter_(1, indices, values)

(where `values, indices = x.topk(k, dim=1)`; note that `scatter_` needs a `dim` argument first.)

However, this seems to mess up any autograd related to these variables…
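One workaround that keeps autograd intact: instead of scattering the top-k values into a fresh tensor (which detaches them from the graph), scatter ones into a mask and multiply. The multiplication is differentiable, so gradients flow through the kept entries and are zero elsewhere. A minimal sketch (tensor shape and `k` are just illustrative):

```python
import torch

x = torch.tensor([[1., 5., 3., 2., 4.],
                  [9., 0., 7., 8., 6.]], requires_grad=True)

# Find the top-k entries along each row.
values, indices = x.topk(k=2, dim=1)

# Build a 0/1 mask with ones at the top-k positions.
mask = torch.zeros_like(x)
mask.scatter_(1, indices, 1.0)

# Multiplying by the mask zeroes non-topk entries while
# preserving the graph through the surviving entries.
out = x * mask
# out:
# tensor([[0., 5., 0., 0., 4.],
#         [9., 0., 0., 8., 0.]], grad_fn=<MulBackward0>)

out.sum().backward()
# x.grad is exactly the mask: 1 where a value survived, 0 elsewhere.
```

The mask itself is built from non-differentiable index operations, but that is fine: it is treated as a constant, and the gradient of `x * mask` with respect to `x` is simply `mask`.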