I think I found the solution. For anyone who runs into a similar issue, it works like this:
1). Find the top-k indices of the one-hot segmentation mask along the class dimension:
_, index = torch.topk(one_hot_result, k=3, dim=1)
2). Expand the indices back to the full number of classes:
# assumes one_hot_result has shape (N, C); num_classes keeps the mask width equal to C
expand = torch.nn.functional.one_hot(index, num_classes=one_hot_result.shape[1])
topk_mask = expand.sum(dim=1)  # note: since there are k maximum values, I treat this as a multi-label mask
3). Multiply with the original one-hot result:
topk_result = torch.mul(topk_mask, one_hot_result)
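For reference, here is a minimal sketch that wraps the three steps above into one function. The function name `keep_topk`, the argument name `scores`, and the assumption that classes live on dim=1 of a tensor shaped (N, C) or (N, C, H, W) are mine, not part of the original code. It passes `num_classes` explicitly so the mask always has C columns, and uses `movedim` to put the class axis back on dim=1 when there are spatial dimensions.

```python
import torch
import torch.nn.functional as F


def keep_topk(scores: torch.Tensor, k: int = 3) -> torch.Tensor:
    """Zero out everything except the k largest entries along dim=1.

    Assumption: `scores` has shape (N, C) or (N, C, *spatial), classes on dim=1.
    """
    num_classes = scores.shape[1]
    # 1) indices of the k largest values per position: (N, k, *spatial)
    _, index = torch.topk(scores, k=k, dim=1)
    # 2) one-hot each index, then sum over the k axis to get a multi-label mask
    expand = F.one_hot(index, num_classes=num_classes)  # (N, k, *spatial, C)
    topk_mask = expand.sum(dim=1)                       # (N, *spatial, C)
    topk_mask = topk_mask.movedim(-1, 1)                # (N, C, *spatial)
    # 3) keep only the selected entries of the original tensor
    return scores * topk_mask.to(scores.dtype)


if __name__ == "__main__":
    scores = torch.tensor([[0.1, 0.4, 0.2, 0.25, 0.05]])
    # keeps 0.4, 0.25 and 0.2, and zeros the other two entries
    print(keep_topk(scores, k=3))
```

Because the indices returned by `torch.topk` are distinct for each position, the summed mask only contains 0s and 1s, so the multiplication leaves the selected values unchanged.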