How to efficiently retrieve all values in tensor B that correspond to each unique value in tensor A

Let’s say I have two tensors:

samples = torch.tensor([50, 60, 40, 50, 50, 60])
labels = torch.tensor([0, 0, 1, 2, 1, 1])

What I would like to do is efficiently retrieve the values in labels for each unique value in samples. So the expected output would be three tensors (assuming the unique samples are sorted, as torch.unique returns them):

torch.tensor([1])
torch.tensor([0, 2, 1])
torch.tensor([0, 1])

Now, of course, I could call torch.unique(samples) to get the unique samples and use the result in a loop to index labels; for example, labels[samples == 50] yields the expected torch.tensor([0, 2, 1]). However, I am not sure that is efficient, as I am still using a for loop. Should I do something smart with the return_inverse parameter of torch.unique? I haven’t solved the puzzle yet!
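For reference, the loop just described fits in one list comprehension. This is a minimal sketch using the toy tensors above; it still runs a Python-level loop and scans all of samples once per unique value, which is the cost the question wants to avoid.

```python
import torch

samples = torch.tensor([50, 60, 40, 50, 50, 60])
labels = torch.tensor([0, 0, 1, 2, 1, 1])

# One boolean mask per unique value (torch.unique sorts by default),
# so the total work is roughly (number of unique values) x len(samples).
groups = [labels[samples == v] for v in torch.unique(samples)]
print(groups)  # [tensor([1]), tensor([0, 2, 1]), tensor([0, 1])]
```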

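An even less efficient way is probably something like this (which is what I currently do):

from collections import defaultdict

occurrences = defaultdict(list)

for sample, label in zip(samples, labels):
    key = tuple(sample.flatten().tolist())  # lists are unhashable, so use a tuple as the dict key
    occurrences[key].append(label.item())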

In my specific problem, samples is either a Python list of 2D tensors or a large 3D tensor (depending on whether the tensors can be stacked, since the last dimension may vary with the configuration). I want to collect the label occurrences for every unique sample because I want to calculate the entropy of these labels. So, currently, after the above code, I have something like this:

from scipy.stats import entropy

for sample, label_occurrences in occurrences.items():
  label_occurrences = torch.tensor(label_occurrences)
  _, counts = torch.unique(label_occurrences, return_counts=True)
  # label frequencies; base=num_classes scales the entropy to [0, 1]
  etr = entropy(counts / label_occurrences.shape[0], base=num_classes)

This also feels very inefficient. How can I make my code more efficient (i.e., vectorised)?

(Bonus: I am using SciPy’s entropy function in the last snippet; probably there are even faster ways to calculate these numbers using Torch’s built-in operations?)
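For the bonus question, one possible fully vectorised variant (a sketch, not something proposed in this thread) replaces both the dictionary and the SciPy call. It assumes labels holds integer class ids in [0, num_classes). It encodes each (unique-sample id, label) pair as a single integer, counts all pairs with one torch.bincount, and computes the base-num_classes entropy per row, using torch.xlogy so that 0 * log(0) counts as 0. The function name label_entropy_per_group is hypothetical.

```python
import math
import torch

def label_entropy_per_group(samples: torch.Tensor, labels: torch.Tensor,
                            num_classes: int) -> torch.Tensor:
    """Entropy (base num_classes) of the labels attached to each unique sample.

    Assumes labels are integer class ids in [0, num_classes). Stacked samples
    of shape (N, ...) are treated as whole rows via unique(dim=0).
    """
    if samples.dim() > 1:
        _, inverse = torch.unique(samples, dim=0, return_inverse=True)
    else:
        _, inverse = torch.unique(samples, return_inverse=True)
    num_groups = int(inverse.max()) + 1
    # Encode each (group id, label) pair as one integer and count all pairs at once
    joint = inverse * num_classes + labels
    counts = torch.bincount(joint, minlength=num_groups * num_classes)
    counts = counts.view(num_groups, num_classes).to(torch.float64)
    # Per-group label frequencies
    p = counts / counts.sum(dim=1, keepdim=True)
    # xlogy(p, p) computes p * log(p) with 0 * log(0) defined as 0
    return -torch.xlogy(p, p).sum(dim=1) / math.log(num_classes)

ent = label_entropy_per_group(torch.tensor([50, 60, 40, 50, 50, 60]),
                              torch.tensor([0, 0, 1, 2, 1, 1]), num_classes=3)
print(ent)  # one entropy per sorted unique sample: 40, 50, 60
```

The dense (unique samples) x num_classes count matrix is the design trade-off here. It is small when num_classes is small, and it avoids any Python-level loop.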

Hi Damiaan!

Take an == “outer product” of samples with samples.unique() (using
broadcasting) to get a mask of the unique-value locations. Use where()
to get the corresponding indices and use them to index into labels.
Then split that result using the unique-value counts.


>>> import torch
>>> print (torch.__version__)
>>> samples = torch.tensor([50, 60, 40, 50, 50, 60])
>>> labels = torch.tensor([0, 0, 1, 2, 1, 1])
>>> vals, cnts = samples.unique (return_counts = True)
>>> umask = samples.unsqueeze(0) == vals.unsqueeze (1)
>>> umask   # just to check what umask looks like
tensor([[False, False,  True, False, False, False],
        [ True, False, False,  True,  True, False],
        [False,  True, False, False, False,  True]])
>>> inds = torch.where (umask)[1]
>>> inds    # just to check what inds looks like
tensor([2, 0, 3, 4, 1, 5])
>>> # index into labels and split according to unique-value counts
>>> labels[inds].split (cnts.tolist())
(tensor([1]), tensor([0, 2, 1]), tensor([0, 1]))


K. Frank


Wow!!! This is clever! That was cool to see!

Thank you so much!

A quick addition:

In practice, if the samples are not numbers but, e.g., stacked multi-dimensional tensors (say every sample has torch.Size([26, 50]), so the example samples tensor has torch.Size([6, 26, 50])), one can use the == operator in the same way Frank explains. The differences are that the unique samples come from samples.unique(dim=0), and that .all(dim=n) must be applied to every dimension that belongs to the sample itself rather than to the indexing of the samples (in the example, .all(dim=3).all(dim=2)).
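The stacked-sample variant just described might look like this. It is a minimal sketch assuming samples has shape (N, *sample_shape) with at least one sample dimension. Flattening the per-sample dimensions and calling .all once is equivalent to chaining .all(dim=3).all(dim=2) in the [6, 26, 50] example. The function name group_labels_by_sample is hypothetical.

```python
import torch

def group_labels_by_sample(samples: torch.Tensor, labels: torch.Tensor):
    """Split labels by unique sample, where each sample is a sub-tensor.

    Assumes samples has shape (N, *sample_shape) with at least one sample dim.
    The intermediate (K, N, *sample_shape) boolean tensor makes this memory-
    hungry: K * samples.numel() bytes, with K the number of unique samples.
    """
    vals, cnts = samples.unique(dim=0, return_counts=True)
    # Broadcast (1, N, ...) == (K, 1, ...) -> (K, N, ...)
    eq = samples.unsqueeze(0) == vals.unsqueeze(1)
    # Collapse every dimension that belongs to the sample itself -> (K, N)
    umask = eq.flatten(start_dim=2).all(dim=2)
    inds = torch.where(umask)[1]
    return labels[inds].split(cnts.tolist())
```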

However, the above is computationally far too expensive for real-world datasets, even relatively small ones. Therefore, in the particular context of my question, a solution is to pass an extra return_inverse=True parameter to the torch.unique() call in Frank’s code and use the returned inverse indices as a replacement for samples. A light-weight replacement for vals is then simply the range over the first dimension of vals (i.e., torch.arange(0, vals.shape[0])). This should solve the memory problem in many cases, although the (number of unique values) × (number of samples) boolean mask remains. In one part of my particular situation, the samples tensor is 7,225,540 elements long with 45,192 unique vals, so that mask alone would take about 45,192 × 7,225,540 ≈ 3.3 × 10^11 bytes (roughly 327 GB). I have therefore implemented a hybrid setting that falls back to the simple loop for these extreme cases (determined by the available RAM on my machine). It would be great if .eq() returned booleans that took a bit instead of a byte, haha!
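Below is a sketch of the return_inverse / torch.arange idea described above, next to an alternative that is not from this thread. The alternative stable-sorts the inverse indices instead of building any (unique values) x (samples) mask, so its memory use grows only with the number of samples. Both assume stacked samples of shape (N, ...). The function names split_via_inverse_mask and split_via_argsort are hypothetical. The stable sort is what keeps the labels inside each group in their original order, matching the mask-based output.

```python
import torch

def split_via_inverse_mask(samples: torch.Tensor, labels: torch.Tensor):
    """Compare small integer ids instead of whole samples.

    The (K, N) boolean mask still costs K * N bytes; only the per-sample
    dimensions disappear from the comparison.
    """
    vals, inverse, cnts = samples.unique(dim=0, return_inverse=True,
                                         return_counts=True)
    # Light-weight stand-in for vals: the ids 0 .. K-1
    ids = torch.arange(0, vals.shape[0], device=inverse.device)
    umask = inverse.unsqueeze(0) == ids.unsqueeze(1)  # (K, N)
    inds = torch.where(umask)[1]
    return labels[inds].split(cnts.tolist())


def split_via_argsort(samples: torch.Tensor, labels: torch.Tensor):
    """Alternative: group by stably sorting the inverse indices (O(N) memory)."""
    _, inverse, cnts = samples.unique(dim=0, return_inverse=True,
                                      return_counts=True)
    # Stable sort keeps the original label order within each group
    order = torch.sort(inverse, stable=True).indices
    return labels[order].split(cnts.tolist())
```

The sort-based version trades the K x N comparison for an O(N log N) sort, which is usually the cheaper side once the number of unique samples is large.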