Hello Dragon!
This won’t work if tensor t is negative (or, more precisely, if its largest unmasked element is negative).
I would do this:
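```python
import torch
large = torch.finfo(t.dtype).max   # assumes t is a floating-point tensor
# assumes msk has zeros where elements of t should be masked out
# and ones where they should be kept
# subtracting large twice overflows masked elements to -inf (for t of ordinary
# magnitude), so any finite unmasked element beats them
(t - large * (1 - msk) - large * (1 - msk)).argmax()
```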
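To see why negative values matter, here is a small self-contained check. It assumes the approach being replied to simply multiplied t by the mask (a hypothetical reconstruction, since that code isn't quoted here), and that msk is a float tensor of zeros and ones rather than a bool tensor. The helper name `masked_argmax` is also just for illustration. For comparison it shows `masked_fill` with `-inf`, which reaches the same result by a different route.

```python
import torch

def masked_argmax(t: torch.Tensor, msk: torch.Tensor) -> torch.Tensor:
    """Index of the largest element of t among positions where msk == 1 (hypothetical helper)."""
    large = torch.finfo(t.dtype).max
    # masked entries are pushed down by 2 * large, overflowing to -inf for ordinary t
    return (t - large * (1 - msk) - large * (1 - msk)).argmax()

t = torch.tensor([-3.0, -1.0, -2.0, -0.5])
msk = torch.tensor([1.0, 1.0, 1.0, 0.0])  # mask out the last element

# assumed naive approach: masked entries become 0, which beats every negative value
naive = (t * msk).argmax()
print(naive.item())                  # 3 (the masked-out element wins)
print(masked_argmax(t, msk).item())  # 1

# alternative: replace masked entries with -inf directly
alt = t.masked_fill(msk == 0, float("-inf")).argmax()
print(alt.item())                    # 1
```

The `masked_fill` version avoids relying on float overflow and also works if msk is a bool tensor (use `~msk` as the fill condition in that case).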
Best.
K. Frank