Replace several values in a tensor

I have a tensor like

import torch
a = torch.Tensor([3, 2, 7, 9, 2, 9, 9])

However, I need to relabel the values so that they run from 0 to n-1, where n is the number of unique elements in a: the smallest value becomes 0, the next smallest becomes 1, and so on. One possible way to compute this seems to be:

r = torch.zeros_like(a)
for i,el in enumerate(torch.unique(a)):
    r[a == el] = i

giving

tensor([1., 0., 2., 3., 0., 3., 3.])
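Each pass of the loop above writes the position of one unique value into r, so r[k] is just the index of a[k] in the sorted output of torch.unique. As an illustration that is not taken from this thread, torch.searchsorted can look up that position for every element at once, which removes the Python loop (though not the call to unique). This sketch assumes a PyTorch version that provides torch.searchsorted (added in 1.6):

```python
import torch

a = torch.tensor([3., 2., 7., 9., 2., 9., 9.])

# torch.unique returns the distinct values sorted in ascending order.
a_unique = torch.unique(a)

# For each element of a, searchsorted returns the index of the first entry
# in a_unique that is >= that element. Every element of a occurs in a_unique,
# so this is exactly its position, i.e. a label in 0..n-1.
r = torch.searchsorted(a_unique, a)
print(r)  # tensor([1, 0, 2, 3, 0, 3, 3])
```

Because it only relies on ordering, this variant also works for float values or very large integers.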

Is there a more efficient way to compute this, possibly avoiding the unique and the loop? They take a while, because I have quite a large number of distinct elements (~10k) and a very large a.

Thanks!

Hi Alex!

There’s no way you can avoid unique() (or its equivalent). That’s
an inherent bit of processing required by your problem.
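Since unique() has to be computed anyway, it may be worth noting (this is an aside, not part of the original reply) that torch.unique can return the relabelled values directly. With return_inverse=True it also returns, for every element of a, the index of that element in the sorted unique values, which is exactly the 0..n-1 labelling asked for:

```python
import torch

a = torch.tensor([3, 2, 7, 9, 2, 9, 9])

# With sorted=True (the default) the unique values come back in ascending
# order; return_inverse=True additionally gives, for each element of a,
# its index in that sorted list of unique values.
a_unique, r = torch.unique(a, sorted=True, return_inverse=True)
print(a_unique)  # tensor([2, 3, 7, 9])
print(r)         # tensor([1, 0, 2, 3, 0, 3, 3])
```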

You can avoid the loop, depending on the details of your use case (for example, if the values in a are non-negative integers that are not ridiculously large), by building a look-up table and indexing into it to translate the values in a:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> a = torch.tensor ([3, 2, 7, 9, 2, 9, 9])
>>> a_unique = torch.unique (a)
>>> lt = torch.full ((a.max() + 1, ), -1)
>>> lt[a_unique] = torch.arange (len (a_unique))
>>> r = lt[a]
>>> r
tensor([1, 0, 2, 3, 0, 3, 3])
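The session above builds a with torch.tensor, which gives an integer tensor, whereas the question uses torch.Tensor, which gives floats, and a float tensor cannot be used as an index. A minimal sketch wrapping the look-up table in a helper that casts to long first (the name remap_to_ranks is hypothetical, chosen for illustration), under the same assumption that the values are small non-negative integers:

```python
import torch

def remap_to_ranks(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map each value of a to its index among the
    sorted unique values of a.

    Assumes the values are non-negative integers (possibly stored as
    floats) with a maximum small enough for a table of size max + 1.
    """
    a = a.long()  # indexing requires an integer dtype
    a_unique = torch.unique(a)
    # Slots for values that never occur in a keep the marker -1.
    lt = torch.full((int(a.max()) + 1,), -1, dtype=torch.long, device=a.device)
    lt[a_unique] = torch.arange(len(a_unique), device=a.device)
    return lt[a]

a = torch.Tensor([3, 2, 7, 9, 2, 9, 9])  # float tensor, as in the question
print(remap_to_ranks(a))  # tensor([1, 0, 2, 3, 0, 3, 3])
```

The table costs memory proportional to a.max(), not to the number of distinct values, which is why the approach is limited to values that are not too large.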

(If the elements of a are integers but some are negative, you can
still use the look-up-table approach by shifting the values of a up
so that they become non-negative.)
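A short sketch of that shift, assuming integer values whose range (max minus min) is small enough for the table; the example values here are made up for illustration. Subtracting a.min() makes the smallest value 0 without changing the ordering, so each element's position among the unique values is unchanged:

```python
import torch

a = torch.tensor([-5, 2, -1, 9, 2, 9, -5])

# Shift so the smallest value becomes 0; order is preserved.
shifted = a - a.min()
a_unique = torch.unique(shifted)
lt = torch.full((int(shifted.max()) + 1,), -1, dtype=torch.long)
lt[a_unique] = torch.arange(len(a_unique))
r = lt[shifted]
print(r)  # tensor([0, 2, 1, 3, 2, 3, 0])
```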

Best.

K. Frank
