Index of first occurrence on sorted 1D tensor

Hello everyone!
I’m trying to implement the following function as efficiently as possible:
Given a sorted 1D tensor of non-negative integers, I want to retrieve the index of the first occurrence of each value. If some integer between 0 and N-1 does not occur, the corresponding output entry should be -1. Thus, the function would return a 1D tensor of N elements.
Is there an efficient way of doing this, better than just a Python for loop over all values?

E.g.

x = torch.tensor([2, 3, 3, 4, 6, 6, 6, 8, 8], dtype=torch.long)
y = arg_first_occurrence(x, N=10)

would return

>>> y
tensor([-1, -1, 0, 1, 3, -1,  4, -1, 7, -1])
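For reference, a plain-loop version of the requested function might look like the following. The name `arg_first_occurrence_loop` is hypothetical; this is only a sketch meant as a baseline to check faster versions against.

```python
import torch

def arg_first_occurrence_loop(x, N):
    # Hypothetical reference version: one Python-level iteration per element.
    out = torch.full((N,), -1, dtype=torch.long)
    for i, v in enumerate(x.tolist()):
        # Only record the first index seen for each value in [0, N);
        # because of this check the loop does not even rely on sortedness.
        if 0 <= v < N and out[v] == -1:
            out[v] = i
    return out

x = torch.tensor([2, 3, 3, 4, 6, 6, 6, 8, 8], dtype=torch.long)
y = arg_first_occurrence_loop(x, N=10)
print(y)
```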

I’m open to solutions that would involve coding a C++/CUDA extension.

Thank you for your help!

Hi tdigt!

I don’t know whether this meets your efficiency requirements, but
here is a scheme that uses no Python loops, only PyTorch tensor operations.

Note that msk = x == lbl.unsqueeze (1) uses broadcasting,
and therefore generates an n x n two-dimensional tensor from your
length-n one-dimensional input, so time and memory grow
quadratically with n. This may be prohibitive for large inputs.
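If the n x n mask is a concern, one alternative (a different technique from the broadcasting scheme, not part of the original answer) is to exploit the sortedness of x with `torch.searchsorted`, which returns, for each candidate value, the leftmost position where it could be inserted into x while keeping x sorted. A value occurs in x exactly when the element at that position equals it. This is a minimal sketch, assuming x is sorted ascending; the function name is hypothetical:

```python
import torch

def arg_first_occurrence_searchsorted(x, N):
    # Assumes x is a 1D tensor sorted in ascending order.
    if x.numel() == 0:
        return torch.full((N,), -1, dtype=torch.long, device=x.device)
    # Candidate values 0, 1, ..., N-1, with the same dtype as x.
    vals = torch.arange(N, dtype=x.dtype, device=x.device)
    # idx[v] = leftmost position where v could be inserted keeping x sorted.
    idx = torch.searchsorted(x, vals)
    # Clamp so that idx == len(x) (v larger than every element) stays indexable.
    safe = idx.clamp(max=x.numel() - 1)
    # v occurs in x exactly when the element at its leftmost slot equals v.
    found = x[safe] == vals
    return torch.where(found, idx, torch.full_like(idx, -1))

x = torch.tensor([2, 3, 3, 4, 6, 6, 6, 8, 8], dtype=torch.long)
y = arg_first_occurrence_searchsorted(x, N=10)
```

This takes O(N log n) time and O(N) extra memory instead of building an n x n tensor.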

Also for convenience, I’ve eliminated the N argument, instead just
using the length of the input tensor (so the result has n = x.numel()
elements, and any value greater than or equal to n is ignored).
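If you want to keep the N argument from the question, the same broadcasting idea can be written with a label range of N and an input of length n, so the mask is N x n. This is a minimal sketch (the name `arg_first_occurrence_n` is hypothetical); it marks non-matching positions with n directly via `torch.where`, so that n means "not found":

```python
import torch

def arg_first_occurrence_n(x, N):
    n = x.numel()
    if n == 0:
        # min over an empty dimension would fail, so handle this case up front.
        return torch.full((N,), -1, dtype=torch.long, device=x.device)
    lbl = torch.arange(N, device=x.device)
    # msk[i, j] is True where x[j] == i; shape (N, n) via broadcasting.
    msk = x.unsqueeze(0) == lbl.unsqueeze(1)
    pos = torch.arange(n, device=x.device)
    # Matching positions keep their index, the rest get n (larger than any index).
    cand = torch.where(msk, pos, torch.full_like(pos, n))
    first = cand.min(dim=1).values
    # Rows with no match still hold n; map them to -1.
    return torch.where(first < n, first, torch.full_like(first, -1))

x = torch.tensor([2, 3, 3, 4, 6, 6, 6, 8, 8], dtype=torch.long)
y = arg_first_occurrence_n(x, N=10)
```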

Here is the code in a small test script:

import torch

print (torch.__version__)

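def arg_first_occurrence (x):
    n = x.numel()
    lbl = torch.arange (n)
    # msk[i, j] is True where x[j] == i (an n x n tensor via broadcasting)
    msk = x == lbl.unsqueeze (1)
    # mna[i, j] is j + 1 where x[j] == i, and 0 elsewhere
    mna = (lbl + 1) * msk
    # replace the zeros by n + 1 so they never win the min below
    mnb = torch.where (mna != 0, mna, (n + 1) * torch.ones (n, 1).long())
    # first matching index for each label i, or n if i does not occur
    foa = mnb.min (dim = 1)[0] - 1
    # map "not found" (n) to -1
    fob = torch.where (foa != n, foa, -torch.ones (n).long())
    return fob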

x = torch.tensor([2, 3, 3, 4, 6, 6, 6, 8, 8], dtype=torch.long)

y = arg_first_occurrence (x)

ychk = torch.tensor([-1, -1, 0, 1, 3, -1,  4, -1, 7], dtype = torch.long)

print ('torch.equal (y, ychk) =', torch.equal (y, ychk))

And here is its output:

>>> import torch
>>>
>>> torch.__version__
'1.7.1'
>>>
>>> def arg_first_occurrence (x):
...     n = x.numel()
...     lbl = torch.arange (n)
...     msk = x == lbl.unsqueeze (1)
...     mna = (lbl + 1) * msk
...     mnb = torch.where (mna != 0, mna, (n + 1) * torch.ones (n, 1).long())
...     foa = mnb.min (dim = 1)[0] - 1
...     fob = torch.where (foa != n, foa, -torch.ones (n).long())
...     return fob
...
>>>
>>> x = torch.tensor([2, 3, 3, 4, 6, 6, 6, 8, 8], dtype=torch.long)
>>>
>>> y = arg_first_occurrence (x)
>>>
>>> ychk = torch.tensor([-1, -1, 0, 1, 3, -1,  4, -1, 7], dtype = torch.long)
>>>
>>> print ('torch.equal (y, ychk) =', torch.equal (y, ychk))
torch.equal (y, ychk) = True
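Because the input is sorted, a linear-time variant is also possible (again a different technique, not the scheme above): the first occurrence of each value is exactly the index where a run of equal values begins, i.e. where x[j] != x[j - 1]. This is a minimal sketch (the name `arg_first_occurrence_diff` is hypothetical), assuming x is sorted ascending; values outside [0, N) are simply skipped:

```python
import torch

def arg_first_occurrence_diff(x, N):
    # Assumes x is a 1D tensor sorted in ascending order.
    out = torch.full((N,), -1, dtype=torch.long, device=x.device)
    if x.numel() == 0:
        return out
    # is_first[j] is True where a run of equal values begins in x.
    is_first = torch.ones_like(x, dtype=torch.bool)
    is_first[1:] = x[1:] != x[:-1]
    starts = torch.nonzero(is_first).squeeze(1)  # start index of each run
    vals = x[starts]                             # value of each run
    # Skip values that have no slot in the output (outside [0, N)).
    keep = (vals >= 0) & (vals < N)
    out[vals[keep].long()] = starts[keep]
    return out

x = torch.tensor([2, 3, 3, 4, 6, 6, 6, 8, 8], dtype=torch.long)
y = arg_first_occurrence_diff(x, N=10)
```

This needs O(n) time and memory and, like the broadcasting version, consists only of tensor operations, so it should also run on the GPU without a custom C++/CUDA extension.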

Best.

K. Frank