Would this work:
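import torch

x = torch.randn(3, 10, 10)
idx = torch.topk(x, k=2, dim=0)[1]  # indices of the 2 largest values along dim 0
x.scatter_(0, idx, 100)  # write 100 at those positions, in place
print(x)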
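One way to check whether it works is to verify the result directly. The sketch below assumes a recent PyTorch, where `Tensor.scatter_(dim, index, value)` accepts a scalar `value`. The fixed seed and the names `orig`, `top_ok`, `rest_ok` and `count_ok` are illustrative additions, not part of the original snippet. It checks three things: every top-2 position along dim 0 now holds 100, the remaining (smallest) entry is untouched, and exactly two entries per (j, k) location were overwritten.

```python
import torch

torch.manual_seed(0)  # fixed seed (illustrative) so the check is reproducible

x = torch.randn(3, 10, 10)
orig = x.clone()  # untouched copy for comparison

# Positions (along dim 0) of the 2 largest values at every (j, k) location.
idx = torch.topk(x, k=2, dim=0).indices  # shape (2, 10, 10), dtype int64

# In place: x[idx[i, j, k], j, k] = 100 for every i, j, k.
x.scatter_(0, idx, 100)

# Every top-2 position now holds 100 ...
top_ok = bool((x.gather(0, idx) == 100).all())

# ... and the remaining entry along dim 0 (the smallest of the 3) is unchanged.
low = orig.argmin(dim=0, keepdim=True)  # shape (1, 10, 10)
rest_ok = torch.equal(x.gather(0, low), orig.gather(0, low))

# Exactly two entries per (j, k) location were overwritten.
count_ok = bool(((x == 100).sum(dim=0) == 2).all())

print(top_ok, rest_ok, count_ok)
```

If all three flags are True, the one-liner does what it looks like: `topk` supplies int64 indices with the same trailing shape as `x`, which is exactly the layout `scatter_` expects along dim 0.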