Dimensions in torch.topk()

I have an (n, n) mask and a corresponding (3, n, n) image. I’m trying to use torch.topk() to get the top k sigmoid pixel values and erase them from the corresponding image. However, I’m struggling to understand how topk() deals with dimensions. Do I need to unsqueeze the matrix to compute this properly? Thanks!

topk takes the top k along a single dimension. So if you want the top k over both spatial dimensions, you need to .view(…) your tensor to combine them into one dimension, call topk on that, and then “unravel” the flat indices back into (row, column) pairs.
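A rough sketch of what that could look like, in case it helps. The function name erase_topk_pixels is made up for illustration, "erasing" is assumed to mean setting the pixel to zero in all three channels, and the sample values are arbitrary. The flat index is unravelled by hand as row = idx // n_cols, col = idx % n_cols.

```python
import torch


def erase_topk_pixels(image, mask, k):
    """Hypothetical helper: zero the k pixels of `image` with the highest sigmoid(mask)."""
    n_rows, n_cols = mask.shape
    # Flatten the two spatial dimensions into one so topk can rank all pixels.
    scores = torch.sigmoid(mask).reshape(-1)  # shape (n_rows * n_cols,)
    _, flat_idx = torch.topk(scores, k)       # indices into the flattened mask
    # "Unravel" the flat indices back into (row, col) coordinates.
    rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
    cols = flat_idx % n_cols
    # Assumption: erasing means writing 0 to every channel at those positions.
    erased = image.clone()
    erased[:, rows, cols] = 0
    return erased, rows, cols


# Arbitrary example data: a 3x3 mask and an all-ones (3, 3, 3) image.
mask = torch.tensor([[0.1, 2.0, -1.0],
                     [3.0, 0.5, -0.2],
                     [1.5, -3.0, 0.0]])
image = torch.ones(3, 3, 3)
erased, rows, cols = erase_topk_pixels(image, mask, k=2)
```

No unsqueeze is needed here: the mask only has to be flattened for the topk call, and the resulting row and column indices are applied to the image across the channel dimension.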

Best regards

Thomas
