How to use `torch.topk` to select the maximum sum of dim=1 on a 4d tensor without loops

Instead of using a for loop to create your `selection_tensor`, you could also use advanced indexing:

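```python
# arange has shape [batch], indices.t() has shape [k, batch]; they broadcast together
selection_tensor[torch.arange(selection.size(0)), indices.t()] = 1
```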
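Here is a minimal, self-contained sketch of how this fits together. It assumes the 4D input has shape `[batch, channels, H, W]`, that the "sum of dim=1" means one score per channel (spatial dims summed), and that `topk` runs over `dim=1`. The names `x`, `channel_sums`, `k`, `selection_loop` and `selection_scatter` are hypothetical. `selection_tensor.size(0)` is used for the batch size, which is what `selection.size(0)` needs to be. The `scatter_` line is an additional alternative, not part of the answer above.

```python
import torch

# Hypothetical input: batch of 2, 5 channels, 4x4 spatial (shapes are assumptions)
torch.manual_seed(0)
x = torch.randn(2, 5, 4, 4)

# Sum over the spatial dims, leaving one score per channel: shape [2, 5]
channel_sums = x.sum(dim=(2, 3))

# Indices of the k channels with the largest sums per sample: shape [2, k]
k = 3
values, indices = torch.topk(channel_sums, k, dim=1)

# Loop version (what the indexing replaces)
selection_loop = torch.zeros_like(channel_sums)
for b in range(channel_sums.size(0)):
    selection_loop[b, indices[b]] = 1

# Indexing version: arange has shape [2], indices.t() has shape [k, 2];
# they broadcast to [k, 2], so entry (j, b) sets selection_tensor[b, indices[b, j]] = 1
selection_tensor = torch.zeros_like(channel_sums)
selection_tensor[torch.arange(selection_tensor.size(0)), indices.t()] = 1

# Equivalent alternative: scatter the value 1.0 into the topk positions along dim=1
selection_scatter = torch.zeros_like(channel_sums).scatter_(1, indices, 1.0)

print(torch.equal(selection_loop, selection_tensor))   # True
print(torch.equal(selection_loop, selection_scatter))  # True
```

Transposing `indices` is what makes the broadcast line up with the row indices; `torch.arange(batch).unsqueeze(1)` combined with the untransposed `indices` would select the same positions.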