Not sure if this is the best way, but I am now using:
torch.zeros(2, 5).scatter_(1, indices, topk)
However, this seems to mess up any autograd related to these variables…
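For reference, here is a minimal sketch of the same "keep the top-k values per row, zero the rest" idea. It assumes the top-k is taken along dim 1 of a 2x5 tensor, and `x` and `k` are hypothetical names. It uses the out-of-place `Tensor.scatter(dim, index, src)` instead of the in-place `scatter_`, so the `topk` values stay in the autograd graph without modifying any tensor in place. I have not confirmed that this is what breaks in the setup above; it only shows that gradients can flow back to `x` through the kept entries.

```python
import torch

torch.manual_seed(0)

# Hypothetical input: 2 rows of 5 scores, with gradients tracked.
x = torch.randn(2, 5, requires_grad=True)
k = 2  # assumed number of entries to keep per row

# Largest k values in each row and their column positions.
topk, indices = x.topk(k, dim=1)

# Out-of-place scatter along dim 1: write each row's top-k values back
# into their original columns of a zero tensor; every other entry stays 0.
sparse = torch.zeros_like(x).scatter(1, indices, topk)

# Backprop a simple scalar; x.grad is 1 at the kept positions, 0 elsewhere.
sparse.sum().backward()
print(x.grad)
```

`torch.zeros_like(x)` also matches `x`'s dtype and device, which the hard-coded `torch.zeros(2, 5)` does not.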