Hi Damiaan!
Take an `==` “outer product” of `samples` with `samples.unique()` (using broadcasting) to get a mask of the unique-value locations. Use `where()` to get the corresponding indices and use them to index into `labels`. Because `where()` returns the `True` locations in row-major order, those indices come out grouped by unique value. Then split that result using the unique-value counts.
Thus:
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> samples = torch.tensor([50, 60, 40, 50, 50, 60])
>>> labels = torch.tensor([0, 0, 1, 2, 1, 1])
>>>
>>> vals, cnts = samples.unique (return_counts = True)
>>>
>>> umask = samples.unsqueeze(0) == vals.unsqueeze (1)
>>> umask # just to check what umask looks like
tensor([[False, False,  True, False, False, False],
        [ True, False, False,  True,  True, False],
        [False,  True, False, False, False,  True]])
>>>
>>> inds = torch.where (umask)[1]
>>> inds # just to check what inds looks like
tensor([2, 0, 3, 4, 1, 5])
>>>
>>> # index into labels and split according to unique-value counts
>>> labels[inds].split (cnts.tolist())
(tensor([1]), tensor([0, 2, 1]), tensor([0, 1]))
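If you want to reuse this, the same steps can be wrapped in a function. What follows is a minimal sketch, not a definitive implementation. It assumes `samples` is a 1-D tensor and `labels` has the same length. The function names `group_labels_mask` and `group_labels_sort` are hypothetical. The second variant swaps in a different technique: a stable `torch.sort()` of `samples`. It should produce the same grouping without building the (number of unique values) × (number of samples) mask, which may matter for large inputs.

```python
import torch


def group_labels_mask(samples: torch.Tensor, labels: torch.Tensor):
    """Group `labels` by the unique values of `samples` using the == mask.

    Hypothetical helper wrapping the broadcasting approach.
    """
    vals, cnts = samples.unique(return_counts=True)
    # (k, 1) == (1, n) broadcasts to a (k, n) boolean mask
    umask = vals.unsqueeze(1) == samples.unsqueeze(0)
    # where() lists True locations row by row, so the column
    # indices come out grouped by unique value
    inds = torch.where(umask)[1]
    return vals, labels[inds].split(cnts.tolist())


def group_labels_sort(samples: torch.Tensor, labels: torch.Tensor):
    """Same grouping via a stable sort (swapped-in technique, no (k, n) mask).

    Hypothetical helper, shown for comparison.
    """
    vals, cnts = samples.unique(return_counts=True)
    # stable=True keeps equal samples in their original order, which
    # matches the ordering that where() produces in the mask version
    _, inds = torch.sort(samples, stable=True)
    return vals, labels[inds].split(cnts.tolist())


if __name__ == "__main__":
    samples = torch.tensor([50, 60, 40, 50, 50, 60])
    labels = torch.tensor([0, 0, 1, 2, 1, 1])
    for fn in (group_labels_mask, group_labels_sort):
        vals, groups = fn(samples, labels)
        # map each unique sample value to its list of labels
        print(fn.__name__, dict(zip(vals.tolist(), [g.tolist() for g in groups])))
```

Both variants should print `{40: [1], 50: [0, 2, 1], 60: [0, 1]}` for the example data. The mask version is the simpler of the two. The sort version should use less memory when there are many unique values.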
Best.
K. Frank