Dimensions in torch.topk()

topk takes the top k along a single dimension. So if you want the top k over the two spatial dimensions, you need to .view(…) your tensor to combine them into one dimension, call topk on that, and then “unravel” the resulting flat indices back into row and column indices.
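Here is a minimal sketch of that idea, assuming a 4-d input of shape (N, C, H, W) whose last two dimensions are the spatial ones. The helper name spatial_topk and the example tensor are made up for illustration. I use reshape instead of view so it also works on non-contiguous tensors, and the unravelling is just integer division and modulo by the width.

```python
import torch

def spatial_topk(x, k):
    # Hypothetical helper: assumes x has shape (N, C, H, W).
    n, c, h, w = x.shape
    # Merge the two spatial dimensions into one of size H * W.
    flat = x.reshape(n, c, h * w)
    # Top k along the merged dimension; flat_idx indexes into H * W.
    values, flat_idx = flat.topk(k, dim=-1)
    # "Unravel" the flat indices back into (row, col) pairs.
    rows = torch.div(flat_idx, w, rounding_mode="floor")
    cols = flat_idx % w
    return values, rows, cols

# Illustrative input, not from the original question.
x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
values, rows, cols = spatial_topk(x, k=2)
```

values, rows and cols all have shape (N, C, k), so x[n, c, rows[n, c, i], cols[n, c, i]] gives back values[n, c, i].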

Best regards

Thomas
