For example, I have a tensor `t` and a mask `m`:
`t = torch.tensor([20, 10, 50, 40])`
`m = numpy.array([True, True, False, True])`
The masked argmax is 3 (the index of 40 in `t`).

Thank you very much.
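One way to pin down what "masked argmax" means, and to get a reference result to check faster versions against, is to select the unmasked indices first and take the argmax only among them. This is a minimal sketch assuming PyTorch and NumPy are installed. `masked_argmax_reference` is a hypothetical helper name, and this select-then-map-back technique differs from the ones proposed in this thread.

```python
import numpy as np
import torch

def masked_argmax_reference(t: torch.Tensor, m: np.ndarray) -> int:
    """Hypothetical helper: index into t of the largest element where m is True."""
    mask = torch.from_numpy(m)             # NumPy bool array -> torch.bool tensor
    kept = torch.nonzero(mask).squeeze(1)  # original indices of the unmasked elements
    # argmax among the kept values, then map that position back to an index into t
    return int(kept[t[kept].argmax()])

t = torch.tensor([20, 10, 50, 40])
m = np.array([True, True, False, True])
print(masked_argmax_reference(t, m))  # 3
```

Note that it raises an error if `m` keeps no elements, because `argmax` of an empty tensor is undefined.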

hmm

Can’t you just multiply them to get

`[20, 10, 0, 40]`

and then take the argmax?

Thanks. I need to convert `m` to a tensor then.
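For the tensors in the question, the multiply-then-argmax idea could look like the sketch below (assuming PyTorch and NumPy). `torch.from_numpy` turns the NumPy boolean mask into a `torch.bool` tensor, and multiplying the integer tensor by it zeroes the masked-out entries. This only gives the right answer when the largest kept value is positive, since every masked entry becomes 0.

```python
import numpy as np
import torch

t = torch.tensor([20, 10, 50, 40])
m = np.array([True, True, False, True])

mask = torch.from_numpy(m)  # bool tensor sharing memory with m
masked = t * mask           # masked-out entries become 0: tensor([20, 10, 0, 40])
idx = masked.argmax()
print(idx.item())  # 3
```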

Hello Dragon!

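This won’t work if tensor `t` is negative (or, more precisely, if its largest unmasked element is negative), because the masked-out elements become 0 and would then be picked by argmax.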

I would do this:

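```python
import torch

t = torch.tensor([20.0, 10.0, 50.0, 40.0])
# msk has zeros where elements of t should be masked out and ones where they should be
# kept; use a float (or int) tensor, since `1 - msk` is not supported for bool tensors
msk = torch.tensor([1.0, 1.0, 0.0, 1.0])
large = torch.finfo(t.dtype).max  # assumes t is a floating-point tensor
(t - large * (1 - msk) - large * (1 - msk)).argmax()  # tensor(3)
```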

Best.

K. Frank
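A quick sketch contrasting the two ideas on data where every kept value is negative (the values are made up for illustration). Multiplying by the mask lets the masked-out element, now 0.0, win the argmax; subtracting `large` pushes it far below every kept value.

```python
import torch

t = torch.tensor([-20.0, -10.0, 50.0, -40.0])
msk = torch.tensor([1.0, 1.0, 0.0, 1.0])  # 1 = keep, 0 = mask out

# Multiplying by the mask: the masked 50.0 becomes 0.0, which beats every kept (negative) value.
naive = (t * msk).argmax().item()

# Subtracting `large` (twice) drives the masked element far below anything that is kept.
large = torch.finfo(t.dtype).max
robust = (t - large * (1 - msk) - large * (1 - msk)).argmax().item()

print(naive, robust)  # 2 1
```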

If you do not want to convert `m` to a tensor, you can use

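```python
import timeit

import numpy as np
import torch

def fn():
    t = torch.randn(10000)
    m = np.random.rand(10000) < 0.5  # NumPy boolean mask
    # find the largest unmasked value, then the first position in t holding it
    return (t == t[m].max()).type(torch.FloatTensor).argmax()

timeit.timeit(fn, number=10000)
# 3.285314051026944 (seconds)
```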

whereas if you convert `m` to a tensor (using K. Frank's code):

```python
import timeit

import numpy as np
import torch

def gn():
    t = torch.randn(10000)
    m = torch.from_numpy(np.random.rand(10000) < 0.5)  # bool tensor
    large = torch.finfo(t.dtype).max
    # ~m is True for masked-out elements, so only those get `large` subtracted
    return (t - large * (~m) - large * (~m)).argmax()

timeit.timeit(gn, number=10000)
# 2.864162279991433 (seconds)
```

Note: benchmarking was done on the CPU.
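One caveat about the equality trick in `fn`: it finds the largest unmasked value and then returns the first position anywhere in `t` holding that value, which can be a masked-out position if the same value also appears there. With continuous random data such as `torch.randn` such a tie is very unlikely, but with integer or repeated data it can happen. A small illustration with made-up values, using a `torch.bool` mask instead of a NumPy array for brevity; it relies on `torch.argmax` returning the first maximal index when there are ties.

```python
import torch

t = torch.tensor([40.0, 10.0, 50.0, 40.0])
m = torch.tensor([False, True, False, True])  # index 0 is masked out but equals the kept max

# Same equality trick as in fn: largest unmasked value, then first position holding it.
idx = (t == t[m].max()).type(torch.FloatTensor).argmax().item()
print(idx, m[idx].item())  # 0 False, i.e. a masked-out position was returned
```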


Thank you.
Do we need both `- large * (1 - msk)` terms, or is one enough?

Thank you for the solution. Converting `m` to tensor is faster in my case as well.

Hi Smile!

The second `- large * (1 - msk)` is protection against an edge case.
If one of the masked values in your tensor were equal to (or close
enough to) `large`, then a single `- large * (1 - msk)` would reduce
it only to zero. If all of the unmasked values in your tensor were
negative, that masked `0.0` would then be the maximum, and argmax
would incorrectly return its index.

Best.

K. Frank
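The edge case can be reproduced with made-up values by placing exactly `large` at a masked position and making every kept value negative. The last computation shows a different technique, not the one proposed in this thread: `Tensor.masked_fill` with `-inf`, which sidesteps the issue because `-inf` stays below any finite value. It assumes `t` is a floating-point tensor.

```python
import torch

large = torch.finfo(torch.float32).max
t = torch.tensor([-3.0, -1.0, large, -2.0])
msk = torch.tensor([1.0, 1.0, 0.0, 1.0])  # index 2 is masked out and equals `large`

single = (t - large * (1 - msk)).argmax().item()                      # masked entry drops only to 0.0
double = (t - large * (1 - msk) - large * (1 - msk)).argmax().item()  # masked entry drops to -large

# Alternative technique: overwrite masked-out positions with -inf before taking argmax.
keep = msk.bool()
alt = t.masked_fill(~keep, float("-inf")).argmax().item()

print(single, double, alt)  # 2 1 1
```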
