Let’s say I have a 2D tensor X = [[1, 2, 3, 4], [1, 2, 3, 4], [3, 4, 5, 6]]. If I use the operation output, inverse_indices = torch.unique(X, sorted=True, return_inverse=True, dim=0), then output would be [[1, 2, 3, 4], [3, 4, 5, 6]] and inverse_indices would be [0, 0, 1].
That is, inverse_indices is a tensor of the same length as X and tells us the index of elements of X with respect to the tensor output.
Is there a way to obtain the “reverse” of inverse indices? That is, I want to obtain the indices of output with respect to X. For example, in this case, I want to obtain the indices as [0, 2] (or [1, 2]). I understand that this “reverse” inverse index is not unique since multiple elements of X can be mapped to the same element in output.
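For concreteness, here is the example above as runnable code (the final assert states the goal: an indices tensor that picks the unique rows back out of X):

```python
import torch

X = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4], [3, 4, 5, 6]])
output, inverse_indices = torch.unique(X, sorted=True, return_inverse=True, dim=0)

assert torch.equal(output, torch.tensor([[1, 2, 3, 4], [3, 4, 5, 6]]))
assert torch.equal(inverse_indices, torch.tensor([0, 0, 1]))

# the goal: an `indices` such that X[indices] == output,
# e.g. indices = [0, 2] or [1, 2]
assert torch.equal(X[torch.tensor([1, 2])], output)
```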
It’s a Python for loop, so not very efficient, but at least it stays $O(n)$:
# note: without dim=..., torch.unique flattens X
output, inverse_indices = torch.unique(X, return_inverse=True)
indices = torch.zeros_like(output, dtype=torch.long)
# iterate over every element of X; the last occurrence of each value wins
for i in range(inverse_indices.shape[0]):
    indices[inverse_indices[i]] = i
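A quick sanity check of this approach on a small 1D tensor (names are my own): since the loop overwrites as it goes, indices[j] ends up holding the last position in x where output[j] occurs, and gathering x at those positions reconstructs output.

```python
import torch

x = torch.tensor([3, 1, 2, 1, 3, 2])
output, inverse_indices = torch.unique(x, return_inverse=True)
indices = torch.zeros_like(output, dtype=torch.long)
for i in range(inverse_indices.shape[0]):
    indices[inverse_indices[i]] = i

# indices[j] is the LAST position in x where output[j] occurs
assert torch.equal(indices, torch.tensor([3, 5, 4]))
assert torch.equal(x[indices], output)
```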
@ggoossen’s function works really well, much faster than the others.
By the way, it may be a minor detail at this point, but in my testing there is a way to make it about 10x faster by not including self (include_self=False):
def torch_unique_with_indices(tensor):
    """Return the unique elements of a tensor and their indices."""
    unique, inverse_indices = torch.unique(tensor, return_inverse=True)
    indices = torch.scatter_reduce(
        torch.zeros_like(unique, dtype=torch.long, device=tensor.device),
        dim=0,
        index=inverse_indices,
        src=torch.arange(tensor.size(0), device=tensor.device),
        reduce="amin",
        include_self=False,
    )
    return unique, inverse_indices, indices
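To make the semantics concrete, here is the same scatter_reduce call inlined on a small example: with reduce="amin", indices holds the first occurrence of each unique value, so tensor[indices] reconstructs unique.

```python
import torch

x = torch.tensor([3, 1, 2, 1, 3, 2])
unique, inverse_indices = torch.unique(x, return_inverse=True)
# for each unique value, keep the minimum (i.e. first) position in x
indices = torch.scatter_reduce(
    torch.zeros_like(unique, dtype=torch.long),
    dim=0,
    index=inverse_indices,
    src=torch.arange(x.size(0)),
    reduce="amin",
    include_self=False,
)

assert torch.equal(indices, torch.tensor([1, 2, 0]))  # first occurrences
assert torch.equal(x[indices], unique)
```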
Time unique:
10.1 ms ± 520 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Time unique with indices, include_self=True:
156 ms ± 1.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Time unique with indices, include_self=False:
12.3 ms ± 75.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
There is only about a 20% overhead compared to torch.unique in my test with a random int tensor of shape (1000000,).
My guess is that initializing with torch.zeros_like is faster than torch.full with a sentinel value — the zeros initialization is only valid with include_self=False, since the initial values are then ignored by the reduction.
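For reproducibility, a sketch of the kind of micro-benchmark behind those numbers (the tensor size matches the test above; the value range is my assumption). With include_self=True the initial values take part in the amin reduction, so they must be a sentinel larger than any valid position; with include_self=False they are ignored and zeros suffice.

```python
import timeit
import torch

x = torch.randint(0, 2**16, (1_000_000,))
unique, inverse = torch.unique(x, return_inverse=True)
src = torch.arange(x.size(0))

def first_idx(include_self):
    if include_self:
        # sentinel: larger than any valid position, so amin ignores it
        init = torch.full_like(unique, x.size(0), dtype=torch.long)
    else:
        # initial values are discarded, so zeros are fine
        init = torch.zeros_like(unique, dtype=torch.long)
    return torch.scatter_reduce(init, 0, inverse, src,
                                reduce="amin", include_self=include_self)

# the two variants agree on the result; only speed differs
assert torch.equal(first_idx(True), first_idx(False))

for include_self in (True, False):
    t = timeit.timeit(lambda: first_idx(include_self), number=10)
    print(f"include_self={include_self}: {t / 10 * 1e3:.1f} ms per call")
```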