Find the topk over a masked input

Hi, I have an input tensor X with dimensions (B, H, L, L).
I also have a binary mask M with the same dimensions. Each row may have a different number of active values.

I would like to get the topk values over the last dim, considering only the masked (active) values.
The issue is that I want this to be efficient and to avoid unnecessary operations.

Can someone help?
Thanks!

Is this what you are looking for?

```python
import torch

def masked_topk(X, M, k):
    # Inactive positions get -inf, so topk never prefers them over an active value.
    masked_X = X.masked_fill(~M.bool(), float('-inf'))

    topk_values, topk_indices = torch.topk(masked_X, k=k, dim=-1)

    # Rows with fewer than k active values are padded with -inf;
    # valid_mask marks the entries that correspond to real (active) values.
    valid_mask = topk_values != float('-inf')

    return topk_values, topk_indices, valid_mask

# Example run:
B, H, L = 2, 3, 5
k = 3

X = torch.rand(B, H, L, L)
M = torch.randint(0, 2, (B, H, L, L)).float()  # binary mask

topk_values, topk_indices, valid_mask = masked_topk(X, M, k)
```
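The `valid_mask` matters when a row has fewer than k active values: `torch.topk` then fills the tail with `-inf` entries whose indices point at arbitrary inactive columns. A minimal sketch on a single made-up row, showing one way to make those padding slots explicit:

```python
import torch

# One row with only two active columns, but we ask for k = 3.
x = torch.tensor([0.3, 0.9, 0.1, 0.7])
m = torch.tensor([True, False, False, True])

vals, idx = torch.topk(x.masked_fill(~m, float("-inf")), k=3)
# vals holds 0.7 and 0.3 followed by -inf; idx[2] points at some inactive column.
valid = vals != float("-inf")
idx = idx.masked_fill(~valid, -1)     # mark padding slots with an impossible column
vals = vals.masked_fill(~valid, 0.0)  # optional: replace -inf padding with a neutral value
```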

Hi,

thanks for the reply, but no… It is a somewhat more complex issue :frowning:
Maybe I was not clear enough in the problem definition.

- This topk approach is not efficient, although it gives the right result.
It has the same complexity as running topk without the mask.
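Before writing any kernel, one could check how much a narrower working width would actually help by timing `torch.topk` over the full row length against the same call over a tensor that only has room for the active entries. A minimal sketch; all sizes, including the packed width `n`, are made-up assumptions, and the numbers depend heavily on hardware and on CPU vs GPU:

```python
import torch
from torch.utils import benchmark

# Made-up sizes: full row length L versus a hypothetical packed width n of active entries.
B, H, L, n, k = 2, 4, 512, 32, 8
device = "cuda" if torch.cuda.is_available() else "cpu"

full = torch.rand(B, H, L, L, device=device)    # topk over every column
packed = torch.rand(B, H, L, n, device=device)  # topk over active columns only

def mean_topk_time(t: torch.Tensor) -> float:
    """Mean seconds per torch.topk call along the last dim."""
    timer = benchmark.Timer(
        stmt="torch.topk(t, k, dim=-1)",
        globals={"torch": torch, "t": t, "k": k},
    )
    return timer.timeit(20).mean

t_full = mean_topk_time(full)
t_packed = mean_topk_time(packed)
print(f"full width {L}: {t_full * 1e3:.3f} ms, packed width {n}: {t_packed * 1e3:.3f} ms")
```

This times only the topk call. Producing the packed tensor costs at least one pass over the mask, so the end-to-end gain is smaller than this comparison suggests.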

Ideally, I am looking for a solution where the cost of topk depends on the mask, i.e. something like
“torch.topk(input, k, mask, dim)” (the current torch.topk has no mask argument).
It should be much faster to find the topk among a handful of inputs (just the active ones) than to scan the entire dimension.
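One way to make the topk step itself scale with the number of active entries, without touching the C++/CUDA code, is to pack the column indices of each row's active entries to the front once, then `gather` and run `torch.topk` over that narrower width. A minimal sketch; `build_packed_index` and `packed_topk` are hypothetical helpers, not a PyTorch API. Caveat: building the index still reads the whole mask (a `cumsum` and a `scatter_` over all L columns), so this can only pay off when the packed width is much smaller than L, and mostly when the same mask M is reused for many X tensors so the index is built once.

```python
import torch

def build_packed_index(M: torch.Tensor, k: int):
    """Hypothetical helper: move the column indices of active entries to the front of each row."""
    M = M.bool()
    L = M.shape[-1]
    n_active = M.sum(dim=-1)                     # active entries per row
    width = min(max(int(n_active.max()), k), L)  # at least k slots so topk(k) is legal
    # Slot of each active entry in its packed row: 0, 1, 2, ... in original column order.
    slot = M.long().cumsum(dim=-1) - 1
    # Route inactive entries to an extra "dump" slot at index `width`, dropped below.
    slot = torch.where(M, slot, torch.full_like(slot, width))
    cols = torch.arange(L, device=M.device).expand_as(slot)
    idx = torch.zeros(*M.shape[:-1], width + 1, dtype=torch.long, device=M.device)
    idx.scatter_(-1, slot, cols)
    idx = idx[..., :width]
    # Slots past a row's active count are padding.
    valid = torch.arange(width, device=M.device) < n_active.unsqueeze(-1)
    return idx, valid

def packed_topk(X: torch.Tensor, idx: torch.Tensor, valid: torch.Tensor, k: int):
    """Hypothetical helper: topk over the packed (active-only) columns of X."""
    vals = X.gather(-1, idx).masked_fill(~valid, float("-inf"))
    top_vals, pos = torch.topk(vals, k, dim=-1)
    top_idx = idx.gather(-1, pos)      # map packed positions back to columns of X
    top_valid = valid.gather(-1, pos)  # False where the row had fewer than k active values
    return top_vals, top_idx, top_valid

# Usage: build the index once per mask, then reuse it for every X.
B, H, L, k = 2, 3, 5, 3
X = torch.rand(B, H, L, L)
M = torch.rand(B, H, L, L) < 0.3
idx, valid = build_packed_index(M, k)
top_vals, top_idx, top_valid = packed_topk(X, idx, valid, k)
```

All rows are padded to the largest active count in the whole tensor, so a single dense row sets the width for every row; the approach helps most when the active counts are uniformly small.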

Note:
- I have considered rewriting torch.topk, but it seems like a big task: it has both C++ and CUDA code, and I did not find documentation for it.

Thanks!