I’m not sure whether someone has asked this before…
Say we have a one-hot segmentation mask of shape BxNxHxW, where B = batch size, N = number of classes, and H, W = height and width of the mask. For each pixel, I’m trying to find the 3 largest values in its N-element class vector and set all the unselected elements to 0 (so the shape of the tensor does not change).
It is like a keepdim version of torch.topk.

I haven’t found a way to do this… I’d appreciate any advice, thanks.

I’m not at all sure what you are asking here. A one-hot vector is a
vector that consists of all zeroes except for one one. So it doesn’t
make sense to ask for the three largest values (or topk() in general)
in a one-hot vector.

Could you illustrate what you are asking with explicit numerical tensors?

Thanks for your feedback.
Let’s just ignore the batch dimension here.
Say we have a segmentation map that is 320x320. The model’s prediction is laid out in a one-hot fashion, with one channel per class, so it is 21x320x320 (21 is the number of classes).
I want to find the 3 largest elements along the first dimension (i.e., the 21 classes) at every pixel of the map, and set all the other, non-selected elements to zero. The result therefore keeps the same shape, 21x320x320.
It is more like a top-k mask applied to the original one-hot result.
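To make the request concrete, here is a tiny hand-made illustration. The numbers, the 5-class 1x2 map, and the names `scores` and `desired` are all made up for this example; they only show what input and output should look like.

```python
import torch

# Hypothetical scores for 5 classes on a 1x2 map (shape: classes x H x W).
scores = torch.tensor([
    [[0.10, 0.70]],
    [[0.50, 0.05]],
    [[0.20, 0.60]],
    [[0.90, 0.30]],
    [[0.30, 0.40]],
])

# Desired output: at each pixel, keep the 3 largest class scores
# and zero out the rest. The shape is unchanged.
desired = torch.tensor([
    [[0.00, 0.70]],
    [[0.50, 0.00]],
    [[0.00, 0.60]],
    [[0.90, 0.00]],
    [[0.30, 0.40]],
])
```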

I think I found the solution. For anyone who might run into a similar issue, it works like this (one_hot_result has shape BxNxHxW):
1) Find the top-k indices along the class dimension:
_, index = torch.topk(one_hot_result, k=3, dim=1)  # index: B x 3 x H x W
2) Expand the indices back to the number of classes:
expand = torch.nn.functional.one_hot(index, num_classes=one_hot_result.shape[1])  # B x 3 x H x W x N
topk_mask = expand.sum(dim=1).permute(0, 3, 1, 2)  # B x N x H x W; since there are k selected classes per pixel, this is a multi-label mask
3) Multiply with the original one-hot result:
topk_result = torch.mul(topk_mask, one_hot_result)
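As a possible alternative to building a mask, the top-k values can be written straight back into a zero tensor with scatter. This is only a minimal sketch: the helper name `topk_keepdim` and the random test tensor are made up for illustration and are not part of the original solution.

```python
import torch

def topk_keepdim(scores: torch.Tensor, k: int = 3, dim: int = 1) -> torch.Tensor:
    """Keep the k largest entries along `dim` and zero the rest (hypothetical helper)."""
    values, index = torch.topk(scores, k=k, dim=dim)
    # Start from zeros and write the top-k values back at their original positions.
    return torch.zeros_like(scores).scatter(dim, index, values)

# Example with an assumed B x N x H x W prediction (2 x 21 x 4 x 4).
pred = torch.rand(2, 21, 4, 4)
topk_result = topk_keepdim(pred, k=3, dim=1)
print(topk_result.shape)  # same shape as pred
```

This avoids the intermediate B x k x H x W x N one-hot tensor, which may matter for large maps.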