May I ask how to do masked argmax in PyTorch?
For example, I have a tensor t and a mask m:
import numpy
import torch
t = torch.tensor([20, 10, 50, 40])
m = numpy.array([True, True, False, True])
The masked argmax is 3 (corresponding to 40 in t).
Thank you very much.
Hmm, can't you just multiply them to get [20, 10, 0, 40] and then take the argmax?
Thanks. I need to convert m to a tensor then.
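A minimal sketch of the multiplication idea, assuming the numpy mask is converted with torch.from_numpy first. The second case uses a made-up all-negative tensor (neg, not from the thread) to show where zeroing out masked entries can go wrong.

```python
import numpy as np
import torch

t = torch.tensor([20, 10, 50, 40])
m = np.array([True, True, False, True])

# Convert the numpy mask to a bool tensor, then zero out the masked-out entries.
mt = torch.from_numpy(m)
print((t * mt).argmax().item())  # 3, since t * mt is [20, 10, 0, 40]

# Hypothetical all-negative input: the zeroed (masked-out) entry becomes the maximum.
neg = torch.tensor([-20, -10, -50, -40])
print((neg * mt).argmax().item())  # 2, a masked-out position (the correct answer is 1)
```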
Hello Dragon!
This won't work if tensor t is negative (or, more precisely, if its largest unmasked element is negative).
I would do this:
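import torch
t = torch.tensor([20.0, 10.0, 50.0, 40.0])
msk = torch.tensor([1.0, 1.0, 0.0, 1.0])
large = torch.finfo(t.dtype).max  # assumes t is a kind of float
# assume msk has zeros where elements of t should be masked out
# and ones where they should be kept
(t - large * (1 - msk) - large * (1 - msk)).argmax()  # tensor(3)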
Best.
K. Frank
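For comparison, here is a sketch of a different technique (not the subtraction trick above): overwrite the masked-out entries with the lowest value the dtype allows using Tensor.masked_fill, then take the argmax. This also handles the integer tensor from the original question, where torch.finfo does not apply. The helper name masked_argmax is hypothetical, and the sketch assumes the mask is a bool tensor with at least one True entry.

```python
import numpy as np
import torch

def masked_argmax(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the largest t[i] where mask[i] is True.

    Assumes mask is a bool tensor shaped like t with at least one True entry;
    if every entry is masked out, the returned index is meaningless.
    """
    if t.is_floating_point():
        fill = float("-inf")
    else:
        fill = torch.iinfo(t.dtype).min
    # masked_fill writes `fill` where its mask is True, so invert the keep-mask.
    return t.masked_fill(~mask, fill).argmax()

t = torch.tensor([20, 10, 50, 40])          # int64, as in the question
m = np.array([True, True, False, True])     # numpy mask, as in the question
print(masked_argmax(t, torch.from_numpy(m)).item())  # 3
```

Filling with -inf (or the integer minimum) avoids the overflow reasoning needed for the large-constant approach, at the cost of a dtype check.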
If you do not want to convert m to a tensor, you can use:
import timeit
import numpy as np
import torch
def fn():
    t = torch.randn(10000)
    m = np.random.rand(10000) < 0.5
    return (t == t[m].max()).type(torch.FloatTensor).argmax()
timeit.timeit(fn, number=10000)
3.285314051026944
whereas if you convert m to a tensor (using K. Frank's code):
def gn():
    t = torch.randn(10000)
    m = torch.from_numpy(np.random.rand(10000) < 0.5)
    large = torch.finfo(t.dtype).max
    return (t - large * (~m) - large * (~m)).argmax()
timeit.timeit(gn, number=10000)
2.864162279991433
Note: benchmarking was done on CPU.
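One caveat worth checking with the first snippet's t == t[m].max() comparison: if a masked-out element happens to equal the largest unmasked value, it also matches, and argmax (which returns the first maximal index) can pick the masked position. A minimal sketch with made-up values, converting the mask up front for clarity, plus one possible fix that ANDs the matches with the mask:

```python
import numpy as np
import torch

t = torch.tensor([5.0, 1.0, 5.0])
m = np.array([False, True, True])   # index 0 is masked out but equals the kept maximum
mt = torch.from_numpy(m)

# Comparison approach: the masked-out index 0 also matches the maximum value.
hits = t == t[mt].max()
print(hits.float().argmax().item())  # 0, a masked-out position

# Restricting the matches to kept positions avoids the tie.
print((hits & mt).float().argmax().item())  # 2
```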
Thank you.
Do we need both - large * (1 - msk) terms, or is one good enough?
Thank you for the solution. Converting m to a tensor is faster in my case as well.
Hi Smile!
The second - large * (1 - msk) is protection against an edge case.
If one of the masked values in your tensor were equal to (or close enough to) large, then a single - large * (1 - msk) would reduce it only to zero (or nearly zero). So if all of the unmasked values in your tensor were negative, you would incorrectly get 0.0 as the maximum value, and argmax would return that masked-out position.
Best.
K. Frank
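A minimal sketch of that edge case, using made-up float32 values: the masked-out entry at index 1 is set to exactly large and the kept entries are negative.

```python
import torch

large = torch.finfo(torch.float32).max
t = torch.tensor([-1.0, large, -2.0])
msk = torch.tensor([1.0, 0.0, 1.0])   # index 1 is masked out

single = t - large * (1 - msk)                      # [-1.0, 0.0, -2.0]
double = t - large * (1 - msk) - large * (1 - msk)  # [-1.0, -large, -2.0]
print(single.argmax().item())  # 1, the masked-out position wins
print(double.argmax().item())  # 0, the largest unmasked value (-1.0)
```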