# How to efficiently retrieve all values in tensor B corresponding to the same value in tensor A

Let’s say I have two tensors:

```python
import torch

samples = torch.tensor([50, 60, 40, 50, 50, 60])
labels = torch.tensor([0, 0, 1, 2, 1, 1])
```

What I would like to do is efficiently retrieve the values in `labels` for each unique value in `samples`. The expected output would therefore be three tensors (assuming the unique samples are sorted in ascending order, as `torch.unique` returns them by default):

```
torch.tensor([1])
torch.tensor([0, 2, 1])
torch.tensor([0, 1])
```

Of course, I could call `torch.unique(samples)` to get the unique samples and loop over them to index `labels`; for example, `labels[samples == 50]` yields the expected `torch.tensor([0, 2, 1])`. However, I am not sure that is efficient, since it still uses a Python `for` loop. Should I do something clever with the `return_inverse` parameter of `torch.unique`? I haven't solved the puzzle yet!
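For reference, the loop-based approach described above might look like the following minimal sketch (the function name `group_labels_loop` is hypothetical). It performs one full boolean comparison over `samples` per unique value, so its cost grows with the number of unique values times the number of samples, on top of the Python loop overhead.

```python
from typing import List

import torch


def group_labels_loop(samples: torch.Tensor, labels: torch.Tensor) -> List[torch.Tensor]:
    groups = []
    # torch.unique returns the unique values sorted in ascending order by default
    for value in torch.unique(samples):
        # Boolean mask selecting every position where samples equals this value
        groups.append(labels[samples == value])
    return groups


samples = torch.tensor([50, 60, 40, 50, 50, 60])
labels = torch.tensor([0, 0, 1, 2, 1, 1])
print(group_labels_loop(samples, labels))
# [tensor([1]), tensor([0, 2, 1]), tensor([0, 1])]
```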

An even less efficient approach (and the one I currently use) looks like this:

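```python
from collections import defaultdict

occurrences = defaultdict(list)

for sample, label in zip(samples, labels):
    # Flatten the sample so it can serve as a hashable dictionary key
    key = sample.flatten().tolist()
    occurrences[tuple(key)].append(label)
```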

In my actual problem, `samples` is either a Python list of 2D tensors or a large 3D tensor, depending on whether the tensors can be stacked (the last dimension may vary with the configuration). I want to collect the labels for every unique sample so that I can calculate the entropy of those labels. So, currently, after the above code, I have something like this:

```python
from scipy.stats import entropy

for sample, label_occurrences in occurrences.items():
    label_occurrences = torch.tensor(label_occurrences)
    _, counts = torch.unique(label_occurrences, return_counts=True)
    etr = entropy(counts / label_occurrences.shape[0], base=num_classes)  # num_classes defined elsewhere
```

This also feels very inefficient. How can I make my code more efficient (i.e., vectorised)?

(Bonus: I am using SciPy’s `entropy` function in the last snippet. Are there faster ways to calculate these values with Torch’s built-in operations?)
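One possible vectorised sketch covering both the grouping and the bonus question, assuming `samples` is a single stacked 3D tensor of shape `(N, H, W)` (the Python-list case would need stacking or padding first, which is not handled here) and that labels are integers in `[0, num_classes)`. It uses `torch.unique(..., dim=0, return_inverse=True)` to give every sample a group id, `torch.bincount` to build a `(num_groups, num_classes)` count matrix in a single call, and `torch.xlogy` to compute the entropy with log base `num_classes`, which is the same formula SciPy's `entropy(pk, base=...)` uses. The function name `label_entropy_per_sample` is hypothetical.

```python
import math

import torch


def label_entropy_per_sample(samples: torch.Tensor, labels: torch.Tensor, num_classes: int):
    # Flatten each sample to a row so identical 2D samples become identical rows
    flat = samples.reshape(samples.shape[0], -1)
    # Distinct rows (sorted) plus, for each sample, the index of its row in uniq
    uniq, inverse = torch.unique(flat, dim=0, return_inverse=True)
    num_groups = uniq.shape[0]
    # Encode each (group, label) pair as one integer; unique because labels < num_classes
    pair_index = inverse * num_classes + labels
    # Count every (group, label) pair at once, then reshape into a count matrix
    counts = torch.bincount(pair_index, minlength=num_groups * num_classes)
    counts = counts.reshape(num_groups, num_classes).to(torch.float64)
    # Empirical label distribution per unique sample
    probs = counts / counts.sum(dim=1, keepdim=True)
    # Entropy in base num_classes; xlogy treats 0 * log(0) as 0
    ent = -torch.xlogy(probs, probs).sum(dim=1) / math.log(num_classes)
    return uniq, ent
```

For the toy data reshaped with `samples.view(-1, 1, 1)` and `num_classes=3`, this should give entropies of 0, 1 and log(2)/log(3) ≈ 0.631 for the unique samples 40, 50 and 60 respectively.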

Hi Damiaan!

Take an `==` “outer product” of `samples` with `samples.unique()` (using
broadcasting) to get a mask of the unique-value locations. Use `where()`
on that mask to get the corresponding indices, grouped by unique value,
and use them to index into `labels`. Then split that result using the
unique-value counts.

Thus:

```python
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> samples = torch.tensor([50, 60, 40, 50, 50, 60])
>>> labels = torch.tensor([0, 0, 1, 2, 1, 1])
>>>
>>> vals, cnts = samples.unique (return_counts = True)
>>>
>>> umask = samples.unsqueeze(0) == vals.unsqueeze (1)
>>> umask
tensor([[False, False,  True, False, False, False],
        [ True, False, False,  True,  True, False],
        [False,  True, False, False, False,  True]])
>>>
>>> inds = torch.where (umask)[1]
>>> inds    # just to check what inds looks like
tensor([2, 0, 3, 4, 1, 5])
>>>
>>> # index into labels and split according to unique-value counts
>>> labels[inds].split (cnts.tolist())
(tensor([1]), tensor([0, 2, 1]), tensor([0, 1]))
```
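The same mask / `where()` / `split()` steps can be wrapped in a reusable function; the name `group_labels` is hypothetical. Because the boolean mask has shape `(num_unique, N)`, its memory grows with the number of unique values times the number of samples. The second function, `group_labels_sorted` (also hypothetical, and an alternative technique rather than the one in the answer above), uses `return_inverse` with a stable `torch.sort` to produce the same grouping without building that mask.

```python
from typing import Tuple

import torch


def group_labels(samples: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Sorted unique values and how often each one occurs
    vals, cnts = samples.unique(return_counts=True)
    # (num_unique, N) mask: row i marks the positions where samples == vals[i]
    umask = samples.unsqueeze(0) == vals.unsqueeze(1)
    # Column indices of True entries in row-major order, i.e. grouped by unique value
    inds = torch.where(umask)[1]
    # Gather labels in grouped order and cut them into one chunk per unique value
    return labels[inds].split(cnts.tolist())


def group_labels_sorted(samples: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # inverse[j] is the position of samples[j] among the sorted unique values
    _, inverse, cnts = samples.unique(return_inverse=True, return_counts=True)
    # A stable sort keeps the original sample order within each group
    _, order = torch.sort(inverse, stable=True)
    return labels[order].split(cnts.tolist())
```

Both functions should return `(tensor([1]), tensor([0, 2, 1]), tensor([0, 1]))` for the example tensors.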

Best.

K. Frank


Wow!!! This is clever! That was cool to see!

Thank you so much!