Advanced indexing with torch.topk

Would this work?

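import torch

x = torch.randn(3, 10, 10)
# indices of the 2 largest values along dim 0, shape (2, 10, 10)
idx = torch.topk(x, k=2, dim=0)[1]
# write 100 into x at those positions, in place
x.scatter_(0, idx, 100)
print(x)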
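To check that the scatter_ call really marks the top-2 entries of every (row, col) position along dim 0, here is a small sketch. The seed, the out-of-place `x.scatter(...)` variant (so `x` stays unchanged), and the mask-based cross-check are illustrative additions, not part of the original suggestion. The mask compares each entry against the k-th largest value in its column, which assumes no ties. That is almost surely true for `randn` data.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 10, 10)
k = 2

# topk returns (values, indices), sorted in descending order along dim 0
values, idx = torch.topk(x, k=k, dim=0)  # both of shape (2, 10, 10)

# Out-of-place variant of the same scatter: x itself is left untouched
marked = x.scatter(0, idx, 100.0)

# Cross-check: an entry is in the top-k along dim 0 iff it is >= the
# k-th largest value of its column (assumes no ties)
kth = values[-1:]            # shape (1, 10, 10), broadcasts over dim 0
mask = x >= kth
expected = torch.where(mask, torch.full_like(x, 100.0), x)

print(torch.equal(marked, expected))          # True
print((marked == 100).sum(dim=0).unique())    # tensor([2])
```

The inverse operation is `x.gather(0, idx)`, which reads the selected entries back out and reproduces `values`.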