Masked argmax in PyTorch?

Hello Dragon!

This won’t work if tensor t is negative (or, more precisely, if its
largest unmasked element is negative).

I would do this:

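import torch
large = torch.finfo (t.dtype).max   # assumes t is a kind of float
# assume msk (a float or integer tensor, not bool) has zeros where
# elements of t should be masked out and ones where they should be kept
# large is subtracted twice (not a typo): masked elements end up at or
# below -large (possibly -inf), so they can't beat any finite unmasked element
(t - large * (1 - msk) - large * (1 - msk)).argmax()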
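Here is a self-contained sketch of the same idea wrapped in a function. It assumes a 1-D float32 t and a float 0/1 msk of the same shape. The names masked_argmax and masked_argmax_where, along with the example values, are made up for illustration. The torch.where version is a swapped-in alternative that replaces masked elements with -inf directly. It is not the approach above.

```python
import torch

def masked_argmax(t: torch.Tensor, msk: torch.Tensor) -> torch.Tensor:
    # msk: 1.0 keeps an element, 0.0 masks it out (float, same shape as t)
    large = torch.finfo(t.dtype).max
    # subtract large twice so masked elements land at or below -large
    # (they typically overflow to -inf); kept elements are unchanged
    return (t - large * (1 - msk) - large * (1 - msk)).argmax()

def masked_argmax_where(t: torch.Tensor, msk: torch.Tensor) -> torch.Tensor:
    # alternative: overwrite masked elements with -inf, then take argmax
    neg_inf = torch.full_like(t, float("-inf"))
    return torch.where(msk.bool(), t, neg_inf).argmax()

# every element is negative; only indices 0 and 2 are kept
t = torch.tensor([-3.0, -1.0, -2.0, -0.5])
msk = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(masked_argmax(t, msk).item())        # → 2
print(masked_argmax_where(t, msk).item())  # → 2
```

The double subtraction matters near the ends of the float range. With t = [3e38, -3e38] and only the second element kept, subtracting large once leaves the masked entry at about -4e37. That is still larger than -3e38, so argmax would wrongly return 0. Subtracting twice pushes it to -inf, and the result is 1.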

Best.

K. Frank