Get top-k results from one-hot segmentation masks while keeping dims

I’m not sure whether someone has asked this before…
Say we have a one-hot segmentation mask of shape BxNxHxW, where B = batch size, N = number of class categories, H = height, and W = width of the mask. I’m trying to find the top 3 values in each per-pixel class vector (the length-N vector at each position of the map), and set all the other, unselected elements to 0, so that the dimensions stay unchanged.
It is like a keepdim version of torch.topk.

I haven’t found a way to solve this… I’d appreciate any advice, thanks.
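One way to get such a keepdim-style top-k is to combine torch.topk with Tensor.scatter: take the k largest values and their indices along the class dimension, then scatter those values into a zero tensor of the original shape. This is a minimal sketch assuming a float score tensor of shape BxNxHxW; the helper name topk_keepdim is hypothetical, not part of PyTorch.

```python
import torch


def topk_keepdim(scores: torch.Tensor, k: int, dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: keep the k largest entries along `dim`, zero the rest.

    The output has exactly the same shape as `scores`.
    """
    values, index = torch.topk(scores, k=k, dim=dim)
    # Write the selected values back at their original class positions;
    # every other position stays 0.
    return torch.zeros_like(scores).scatter(dim, index, values)


# Usage: batch of 2, 21 classes, 4x4 map.
scores = torch.rand(2, 21, 4, 4)
out = topk_keepdim(scores, k=3)
print(out.shape)  # torch.Size([2, 21, 4, 4])
print((out != 0).sum(dim=1).unique())  # tensor([3]): 3 kept classes per pixel
```

Scattering into a fresh zero tensor avoids building an explicit mask, and the shape is preserved by construction because zeros_like copies it from the input.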

Hi CBGxd!

I’m not at all sure what you are asking here. A one-hot vector is a
vector that consists of all zeroes except for one one. So it doesn’t
make sense to ask for the three largest values (or topk() in general)
in a one-hot vector.

Could you illustrate what you are asking with explicit numerical tensors?

Best.

K. Frank

Thanks for your feedback.
Let’s ignore the batch dimension here.
Say we have a segmentation map of size 320x320, and the model’s prediction comes in a one-hot layout, i.e., 21x320x320 (21 is the number of class categories).
What I want is to find the 3 largest elements along the first dimension (i.e., the 21 classes) at every position of the map, and fill all the other, non-selected elements with zero. The result therefore keeps the same shape, 21x320x320.
It is more like a top-k mask applied to the original one-hot result.
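To illustrate the requested behaviour with explicit numbers, here is a toy version of the unbatched CxHxW case: 4 classes on a 1x2 map instead of 21x320x320. This sketch uses a different technique from the one in the thread: it finds the k-th largest score at each pixel and keeps every score at or above it. That is an assumption worth flagging, because with tied scores this threshold form can keep more than k classes, whereas index-based selection keeps exactly k.

```python
import torch

# Toy stand-in for the 21x320x320 prediction: 4 classes on a 1x2 map (C x H x W).
pred = torch.tensor([
    [[0.1, 0.7]],   # class 0
    [[0.4, 0.2]],   # class 1
    [[0.3, 0.06]],  # class 2
    [[0.2, 0.04]],  # class 3
])

k = 3
# k-th largest score at each pixel, keeping a class axis of size 1 (shape 1x1x2).
kth = pred.topk(k, dim=0).values[-1:]
# Keep scores >= the k-th largest and zero out the rest; shape stays 4x1x2.
masked = pred * (pred >= kth)

# Pixel (0, 0) keeps classes 1, 2, 3; pixel (0, 1) keeps classes 0, 1, 2.
print(masked)
```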

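I think I found the solution. For anyone who runs into a similar issue, it works like this (one_hot_result is the BxNxHxW prediction):
import torch
1) Find the top-k indices of the one-hot segmentation mask along the class dimension:
_, index = torch.topk(one_hot_result, k=3, dim=1)  # B x 3 x H x W
2) Expand the indices back over the classes (pass num_classes so the class axis has length N):
expand = torch.nn.functional.one_hot(index, num_classes=one_hot_result.size(1))  # B x 3 x H x W x N
topk_mask = expand.sum(dim=1).permute(0, 3, 1, 2)  # note: since we keep k maximal values, I treat it as multi-label; back to B x N x H x W
3) Multiply with the original one-hot result:
topk_result = torch.mul(topk_mask, one_hot_result)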
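Putting the three steps together, here is a self-contained sketch of the one_hot-based approach for a batched BxNxHxW tensor, cross-checked against the scatter formulation. It assumes float class scores; the function name topk_mask_result is hypothetical. The permute is needed because F.one_hot appends the class axis last, and passing num_classes explicitly prevents one_hot from inferring a smaller class count when the highest classes are never selected.

```python
import torch
import torch.nn.functional as F


def topk_mask_result(scores: torch.Tensor, k: int = 3) -> torch.Tensor:
    """Hypothetical helper: zero all but the k largest class scores per pixel.

    Input and output are both B x N x H x W.
    """
    num_classes = scores.size(1)
    # 1) Indices of the k largest scores along the class dim: B x k x H x W.
    _, index = torch.topk(scores, k=k, dim=1)
    # 2) One-hot encode them (B x k x H x W x N), merge the k picks into a
    #    multi-label 0/1 mask, and move the class axis back to dim 1.
    mask = F.one_hot(index, num_classes=num_classes).sum(dim=1).permute(0, 3, 1, 2)
    # 3) Keep the original scores only where the mask is set.
    return scores * mask.to(scores.dtype)


scores = torch.rand(2, 21, 8, 8)
result = topk_mask_result(scores, k=3)

# Cross-check against scattering the top-k values into a zero tensor.
values, index = torch.topk(scores, k=3, dim=1)
expected = torch.zeros_like(scores).scatter(1, index, values)
print(torch.equal(result, expected))  # True
```

Both versions give the same result; the scatter form skips the intermediate B x k x H x W x N one-hot tensor, which matters for large maps such as 21x320x320.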