Masked argmax in PyTorch?

May I ask how to do a masked argmax in PyTorch?

For example, I have a tensor t and a mask m:
t = torch.tensor([20, 10, 50, 40])
m = numpy.array([True, True, False, True])
Considering only the elements where m is True, the masked argmax is 3 (corresponding to 40 in t).

Thank you very much.

hmm

Can’t you just multiply them to get

[20, 10, 0, 40]

and then do argmax?

Thanks. I need to convert m to a tensor then.
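A minimal sketch of that conversion plus the multiply-then-argmax suggestion, using the tensor and mask from the question. torch.from_numpy turns the NumPy bool array into a bool tensor, and multiplying by it zeroes out the masked entries. This only gives the right answer because the largest unmasked value here is positive.

```python
import numpy as np
import torch

t = torch.tensor([20, 10, 50, 40])
m = np.array([True, True, False, True])

# torch.from_numpy shares memory with the NumPy array; the result is a bool tensor
m_t = torch.from_numpy(m)

# Multiplying by the bool mask zeroes out the masked-out entry: [20, 10, 0, 40]
masked = t * m_t
idx = masked.argmax().item()
print(idx)  # 3
```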

Hello Dragon!

This won’t work if tensor t is negative (or, more precisely, if its
largest unmasked element is negative).
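To make the failure concrete, here is the question's mask with every value of t made negative (the values are invented for illustration). The masked-out entry becomes 0.0, which beats every real candidate, so argmax returns the masked position 2 instead of 1 (where the largest unmasked value, -10.0, sits).

```python
import torch

# Same mask as in the question, but every value of t is negative
t = torch.tensor([-20.0, -10.0, -50.0, -40.0])
m = torch.tensor([True, True, False, True])

# The masked entry becomes 0.0, which is now larger than every unmasked value
naive_idx = (t * m).argmax().item()

# The intended answer is 1 (value -10.0), but the masked-out position wins
print(naive_idx)  # 2
```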

I would do this:

import torch

large = torch.finfo(t.dtype).max   # assumes t is a floating-point tensor
# assumes msk has zeros where elements of t should be masked out
# and ones where they should be kept (a numeric tensor, not a bool one)
(t - large * (1 - msk) - large * (1 - msk)).argmax()

Best.

K. Frank
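Applied to the example from the question, the recipe might look like the sketch below. It assumes t is converted to a float tensor first (torch.finfo raises an error for integer dtypes such as the question's int64 tensor), and it turns the NumPy bool mask into ones and zeros with torch.from_numpy(...).to(t.dtype), because 1 - msk is not supported on a bool tensor.

```python
import numpy as np
import torch

t = torch.tensor([20, 10, 50, 40]).float()   # the recipe needs a float dtype
m = np.array([True, True, False, True])

# Convert the NumPy bool mask to a float tensor: 1.0 = keep, 0.0 = mask out
msk = torch.from_numpy(m).to(t.dtype)

large = torch.finfo(t.dtype).max
# Masked entries are pushed down to -inf; kept entries are unchanged
idx = (t - large * (1 - msk) - large * (1 - msk)).argmax().item()
print(idx)  # 3
```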

If you do not want to convert m to a tensor, you can use:

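import timeit
import numpy as np
import torch

def fn():
    t = torch.randn(10000)
    m = np.random.rand(10000) < 0.5   # NumPy bool mask, not converted
    # largest unmasked value, then the first index where t equals it
    return (t == t[m].max()).float().argmax()

timeit.timeit(fn, number=10000)
# 3.285314051026944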

whereas converting m to a tensor first and using K. Frank's code gives:

def gn():
    t = torch.randn(10000)
    m = torch.from_numpy(np.random.rand(10000) < 0.5)
    large = torch.finfo(t.dtype).max
    return (t - large * (~m) - large * (~m)).argmax()

timeit.timeit(gn, number=10000)
# 2.864162279991433

Note: benchmarking was done on the CPU.
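One caveat with the equality-based fn above: it returns the first position whose value equals the largest unmasked value (torch.argmax returns the first maximal index), and that position may itself be masked out when values repeat. With continuous torch.randn data a tie is very unlikely, but with integer-like data it can happen. The helper name eq_masked_argmax and the tensor values below are hypothetical, chosen only to trigger such a tie.

```python
import numpy as np
import torch

def eq_masked_argmax(t, m):
    # Same idea as fn: find the largest unmasked value, then the first index equal to it
    return (t == t[m].max()).float().argmax().item()

t = torch.tensor([40.0, 10.0, 50.0, 40.0])
m = np.array([False, True, False, True])   # index 0 is masked out

# The intended answer is 3, but the masked index 0 also holds 40.0 and comes first
print(eq_masked_argmax(t, m))  # 0
```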


Thank you.
Do we need both - large * (1 - msk) terms, or is one enough?

Thank you for the solution. Converting m to a tensor is faster in my case as well.

Hi Smile!

The second - large * (1 - msk) term protects against an edge case.
If one of the masked values in your tensor were equal to (or close
enough to) large, a single - large * (1 - msk) would only reduce it
to zero. If all of the unmasked values in your tensor were negative,
that masked 0.0 would then be the maximum, and argmax would
incorrectly return its index.

Best.

K. Frank
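The edge case can be reproduced with a masked-out element equal to large and only negative unmasked values (the values are chosen for illustration). The last computation swaps in a different technique that was not suggested in the thread: Tensor.masked_fill with -inf, which pushes masked entries below every finite value in a single step.

```python
import torch

large = torch.finfo(torch.float32).max
t = torch.tensor([large, -1.0, -2.0])   # index 0 is masked out and equals large
msk = torch.tensor([0.0, 1.0, 1.0])     # zeros mark masked-out elements

# One subtraction: the masked element drops only to 0.0 and wins
single = (t - large * (1 - msk)).argmax().item()

# Two subtractions: the masked element drops to -large, so index 1 wins
double = (t - large * (1 - msk) - large * (1 - msk)).argmax().item()

# Alternative (not from the thread): fill masked positions with -inf directly
filled = t.masked_fill(msk == 0, float("-inf")).argmax().item()

print(single, double, filled)  # 0 1 1
```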
