Lexicographic max operator

Hi,

I’m trying to implement a lexicographic max operator that does the following:
given a tensor of N-dimensional elements (one per row), return the first non-zero value of each element according to the given order.
For example:
input = [[0, 1, 2],
         [1, 4, 2],
         [0, 0, 3],
         [0, 3, 1]]
output = [[1],
          [1],
          [3],
          [3]]
Is there an existing method that does this? If not, could you suggest an efficient way to implement it (one that would also work for “batch” inputs)?
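To pin down the intended semantics before looking for a fast version, here is a slow plain-Python reference. It is only a sketch: the helper name first_nonzero_ref is hypothetical, and it assumes that a row with no non-zero entry should return 0.

```python
import torch

def first_nonzero_ref(x: torch.Tensor) -> torch.Tensor:
    """Slow reference: first non-zero entry of each row, returned with shape [N, 1].

    Assumption (not stated in the question): an all-zero row yields 0.
    """
    out = []
    for row in x.tolist():
        # Scan the row left to right and stop at the first non-zero value.
        first = next((v for v in row if v != 0), 0)
        out.append([first])
    return torch.tensor(out, dtype=x.dtype)

x = torch.tensor([[0, 1, 2],
                  [1, 4, 2],
                  [0, 0, 3],
                  [0, 3, 1]])
print(first_nonzero_ref(x))
# tensor([[1],
#         [1],
#         [3],
#         [3]])
```

Any faster implementation can be checked against this loop on random inputs.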

Thanks!

I’m unsure whether any built-in method can be used directly. tensor.scatter_reduce_ could be a candidate, but it applies the reduce argument to the values themselves, so reducing with "amin" gives the wrong result in the last row: you want to return the 3 (the first non-zero value), while 1 is the minimum value:

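import torch

x = torch.tensor([[0, 1, 2],
                  [1, 4, 2],
                  [0, 0, 3],
                  [0, 3, 1]])

# (row, column) coordinates of every non-zero entry
idx = x.nonzero()
out = torch.zeros(x.size(0), dtype=torch.long)
# Per row, this keeps the smallest non-zero *value*, not the first one
out.scatter_reduce_(dim=0, index=idx[:, 0], src=x[idx[:, 0], idx[:, 1]], reduce="amin", include_self=False)
print(out)
# tensor([1, 1, 3, 1])  <- the last row should be 3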

EDIT:
OK, this should work:

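# Reduce over the column indices of the non-zero entries instead of their values:
# the smallest column index per row is the position of the first non-zero entry.
# An all-zero row receives no update (include_self=False), so it keeps index 0.
idx = x.nonzero()
index = torch.zeros(x.size(0), dtype=torch.long)
index.scatter_reduce_(dim=0, index=idx[:, 0], src=idx[:, 1], reduce="amin", include_self=False)
print(index)
# tensor([1, 0, 2, 1])

# Pick one element per row at those positions
x[torch.arange(x.size(0)), index]
# tensor([1, 1, 3, 3])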
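For batch inputs, a different technique than scatter_reduce_ may be simpler: mark the non-zero entries, take argmax over the last dimension (PyTorch documents that argmax returns the index of the first maximal value when there are ties), and gather at that index. This is a minimal sketch, not the method above: the helper name first_nonzero is hypothetical, and it assumes the order runs along the last dimension and that an all-zero row should return 0.

```python
import torch

def first_nonzero(x: torch.Tensor) -> torch.Tensor:
    """First non-zero entry along the last dim, kept as a size-1 dim.

    Works for any number of leading (batch) dims, e.g. [B, N, D] -> [B, N, 1].
    Assumption: an all-zero row returns x[..., 0], i.e. 0.
    """
    # 1 where the entry is non-zero, 0 elsewhere.
    mask = (x != 0).long()
    # argmax returns the first maximal index, i.e. the position of the first 1.
    # In an all-zero row every entry ties at 0, so index 0 is returned.
    idx = mask.argmax(dim=-1, keepdim=True)
    # Select x[..., idx] along the last dim.
    return torch.gather(x, -1, idx)

x = torch.tensor([[0, 1, 2],
                  [1, 4, 2],
                  [0, 0, 3],
                  [0, 3, 1]])
print(first_nonzero(x).squeeze(-1).tolist())
# [1, 1, 3, 3]

# A batch of two [4, 3] inputs: the original and a column-reversed copy.
xb = torch.stack([x, x.flip(-1)])
print(first_nonzero(xb).squeeze(-1).tolist())
# [[1, 1, 3, 3], [2, 2, 3, 1]]
```

Because it never calls nonzero(), this version keeps the input shape throughout and needs no flattening of the batch dimensions.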