Hi, I have an input tensor X with shape (B, H, L, L).
I also have a binary mask M with the same shape. Each row may have a different number of active values.
I would like to get the top-k values over the last dim, considering only the masked (active) values.
The issue is that I want this to be efficient and to avoid unnecessary operations.
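For reference, here is a minimal sketch of the straightforward approach. It assumes PyTorch, a floating-point X, a boolean M, k <= L, and active values that are never -inf. Inactive positions are set to -inf, and a dense torch.topk then runs over the full last dim, so the cost is the same as with no mask at all. `dense_masked_topk` is a hypothetical helper name. Rows with fewer than k active values get index -1 in the leftover slots.

```python
import torch


def dense_masked_topk(x: torch.Tensor, mask: torch.Tensor, k: int):
    """Top-k of x over the last dim, restricted to positions where mask is True.

    x: floating-point tensor of shape (..., L); mask: bool tensor of the same shape.
    Returns (values, indices) of shape (..., k). Hypothetical helper: slots beyond
    a row's active count hold value -inf and index -1.
    """
    # Push inactive entries to -inf so they sort last; topk still scans all L entries.
    filled = x.masked_fill(~mask, float("-inf"))
    values, indices = filled.topk(k, dim=-1)  # sorted=True by default; needs k <= L
    # With sorted output, the first min(n_active, k) slots are the real results.
    n_active = mask.sum(dim=-1, keepdim=True)      # (..., 1) active count per row
    slot = torch.arange(k, device=x.device)        # (k,)
    valid = slot < n_active                        # (..., k) via broadcasting
    indices = indices.masked_fill(~valid, -1)
    return values, indices


# Usage on a single row: only positions 0, 2 and 3 are active.
x = torch.tensor([[1.0, 5.0, 3.0, 4.0]])
m = torch.tensor([[True, False, True, True]])
v, i = dense_masked_topk(x, m, 2)  # values -> [[4., 3.]], indices -> [[3, 2]]
```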
Thanks for the reply, but no … it is a bit more complex than that.
Maybe I was not very clear in the problem definition.
-This topk approach is not efficient, although it will give you the right solution…
It has the same complexity as running topk without the mask, since it still scans all L entries of every row.
Ideally, I am looking for a solution where the cost of the topk depends on the mask!
Something like “torch.topk(input, k, mask, dim)” (torch.topk has no mask argument today).
It should be much faster to find the topk among a handful of inputs (just the active ones) than to scan the entire dimension.
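Here is a sketch of one way to make the topk call itself depend on the mask, assuming PyTorch and a boolean mask. Each row's active entries are packed to the front of a narrower buffer (via cumsum + scatter_), and torch.topk runs only on that buffer. `compact_masked_topk` is a hypothetical helper, not a PyTorch API. There are some caveats. The cumsum and scatter_ still touch all L entries once, and the buffer width is set by the row with the most active entries, because rows have different active counts. The `.item()` call also forces a host-device sync on GPU. Whether this beats a dense masked topk therefore depends on how sparse M is, so it is worth timing both.

```python
import torch


def compact_masked_topk(x: torch.Tensor, mask: torch.Tensor, k: int):
    """Masked top-k that runs torch.topk on a compacted buffer (hypothetical helper).

    Active entries of each row are packed to the front of a buffer of width
    max(k, max active count over all rows), so topk scans that width instead of L.
    Assumes active values are never -inf. Missing slots get value -inf, index -1.
    """
    *lead, L = x.shape
    n_active = mask.sum(dim=-1)                    # (...) active count per row
    width = max(k, int(n_active.max().item()))     # host sync on GPU
    # Slot of each active entry within its row: 0, 1, 2, ...
    slot = mask.long().cumsum(dim=-1) - 1
    # Inactive entries are routed to an extra dummy column at index `width`.
    slot = torch.where(mask, slot, torch.full_like(slot, width))

    vals_buf = torch.full((*lead, width + 1), float("-inf"), dtype=x.dtype, device=x.device)
    idx_buf = torch.full((*lead, width + 1), -1, dtype=torch.long, device=x.device)
    pos = torch.arange(L, device=x.device).expand_as(slot)  # original column indices
    vals_buf.scatter_(-1, slot, x)   # pack active values; dummy column gets junk
    idx_buf.scatter_(-1, slot, pos)  # remember where each packed value came from

    # Drop the dummy column, run topk on the compact buffer, map back to original positions.
    values, local = vals_buf[..., :width].topk(k, dim=-1)
    indices = idx_buf[..., :width].gather(-1, local)
    return values, indices


# Usage on a single row: only positions 0, 2 and 3 are active.
x = torch.tensor([[1.0, 5.0, 3.0, 4.0]])
m = torch.tensor([[True, False, True, True]])
v, i = compact_masked_topk(x, m, 2)  # values -> [[4., 3.]], indices -> [[3, 2]]
```

The dummy column avoids a branch for inactive entries. They all scatter into one throwaway slot, and the order in which they overwrite it does not matter because that column is sliced off before topk.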
Note:
-I have considered rewriting torch.topk, but it seems like a big task: it also has C++ and CUDA code, and I did not find documentation for it.