How to efficiently retrieve all values in tensor B that correspond to the same value in tensor A

Hi Damiaan!

Take an == “outer product” of samples with samples.unique() (using
broadcasting) to get a mask of the locations of each unique value. Use
where() to get the corresponding indices and use them to index into labels.
Then split that result using the unique-value counts.
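For reuse, the same steps can be wrapped in a small function. This is a minimal sketch, assuming samples and labels are 1-D tensors of equal length; the name group_labels_by_sample is hypothetical. It relies on torch.where() with a single argument returning (row, column) index tensors in row-major order, which is what keeps the column indices grouped by unique value.

```python
import torch

def group_labels_by_sample(samples: torch.Tensor, labels: torch.Tensor):
    """Hypothetical helper: one tensor of labels per unique value in samples,
    with the unique values in ascending order (as returned by unique())."""
    # Sorted unique values and the number of times each occurs.
    vals, cnts = samples.unique(return_counts=True)
    # == "outer product" via broadcasting: shape (num_unique, len(samples)).
    umask = samples.unsqueeze(0) == vals.unsqueeze(1)
    # Row-major order means the column indices come out grouped by unique value.
    inds = torch.where(umask)[1]
    # Gather the labels and cut them into one chunk per unique value.
    return labels[inds].split(cnts.tolist())

samples = torch.tensor([50, 60, 40, 50, 50, 60])
labels = torch.tensor([0, 0, 1, 2, 1, 1])
groups = group_labels_by_sample(samples, labels)
print(groups)
```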

Thus:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> samples = torch.tensor([50, 60, 40, 50, 50, 60])
>>> labels = torch.tensor([0, 0, 1, 2, 1, 1])
>>>
>>> vals, cnts = samples.unique (return_counts = True)
>>>
>>> umask = samples.unsqueeze(0) == vals.unsqueeze (1)
>>> umask   # just to check what umask looks like
tensor([[False, False,  True, False, False, False],
        [ True, False, False,  True,  True, False],
        [False,  True, False, False, False,  True]])
>>>
>>> inds = torch.where (umask)[1]
>>> inds    # just to check what inds looks like
tensor([2, 0, 3, 4, 1, 5])
>>>
>>> # index into labels and split according to unique-value counts
>>> labels[inds].split (cnts.tolist())
(tensor([1]), tensor([0, 2, 1]), tensor([0, 1]))
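Note that umask has (number of unique values) × (number of samples) elements, which can get large. A possible alternative, not the approach above, is a stable sort: sorting samples stably puts equal values next to each other while keeping their original relative order, so the sort indices equal inds above and can be split with the same counts. This is a sketch assuming a PyTorch version whose torch.sort() accepts stable=True (2.0.0, as used above, does); the name group_labels_by_sample_sorted is hypothetical.

```python
import torch

def group_labels_by_sample_sorted(samples: torch.Tensor, labels: torch.Tensor):
    """Hypothetical sort-based variant: avoids building the
    (num_unique, len(samples)) mask."""
    # unique() returns values in ascending order, matching the sort below.
    _, cnts = samples.unique(return_counts=True)
    # Stable sort: ties keep their original order, so order == inds above.
    _, order = torch.sort(samples, stable=True)
    # Gather the labels and split them into one chunk per unique value.
    return labels[order].split(cnts.tolist())

samples = torch.tensor([50, 60, 40, 50, 50, 60])
labels = torch.tensor([0, 0, 1, 2, 1, 1])
groups_sorted = group_labels_by_sample_sorted(samples, labels)
print(groups_sorted)
```

The sort costs O(N log N) time and O(N) extra memory, whereas the mask version uses memory proportional to N times the number of unique values.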

Best.

K. Frank
