The torch.unique() function returns the unique values and the inverse indices, but unlike np.unique, doesn’t return the indices of the first occurrence of these unique values. Is there any way of doing this in pytorch without using a for loop for each unique value?
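For reference, here is a small illustration of the gap being described. The input values are arbitrary, chosen so the first occurrences are easy to check by hand.

```python
import numpy as np
import torch

values = [3, 1, 3, 2, 1]

# NumPy: return_index gives the position of each unique value's first occurrence
uniq_np, first_idx_np = np.unique(np.array(values), return_index=True)
# uniq_np      → [1 2 3]
# first_idx_np → [1 3 0]

# PyTorch: return_inverse (and return_counts) exist, but there is no return_index
uniq_t, inverse_t = torch.unique(torch.tensor(values), return_inverse=True)
# uniq_t    → tensor([1, 2, 3])
# inverse_t → tensor([2, 0, 2, 1, 0])
```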
I was wondering about the same problem today, and I’ve found a tricky workaround (certainly not elegant).
Here’s the example code:
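    import torch

    a = torch.rand(2, 2)
    a = torch.vstack([a, a])  # rows 2-3 repeat rows 0-1
    uni, rev_idxs = a.unique(dim=0, return_inverse=True)
    # Caveat: rev_idxs.unique() always equals arange(len(uni)), so this picks the
    # first len(uni) rows, which are the first occurrences only in cases like this one
    first_occ_idxs = rev_idxs.unique()
    print(a[first_occ_idxs])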
If anyone can suggest a more elegant solution, I would adopt it.
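Because `rev_idxs.unique()` is always `0..len(uni)-1`, the workaround above only returns first occurrences when those happen to be the leading rows. A minimal sketch of a general version for `dim=0`, assuming a PyTorch version whose `torch.sort` accepts `stable=True`: stably sort the inverse indices so equal rows are grouped with their original positions in ascending order, then take the first entry of each group. The helper name `first_occurrence_rows` is hypothetical.

```python
import torch

def first_occurrence_rows(a):
    """Return (unique rows of a, index of each unique row's first occurrence)."""
    uni, inverse, counts = torch.unique(a, dim=0, return_inverse=True, return_counts=True)
    # A stable sort groups equal inverse ids while keeping the original
    # row positions in ascending order inside each group
    _, order = torch.sort(inverse, stable=True)
    # Group i starts after the rows of groups 0..i-1
    starts = torch.cat([counts.new_zeros(1), counts.cumsum(0)[:-1]])
    return uni, order[starts]


a = torch.tensor([[5., 6.], [1., 2.], [5., 6.], [3., 4.], [1., 2.]])
uni, first_idx = first_occurrence_rows(a)
# uni rows  → [1, 2], [3, 4], [5, 6]
# first_idx → tensor([1, 3, 0])
```

Since `a[first_idx]` equals `uni` by construction, this also works as a drop-in for the `print(a[first_occ_idxs])` line.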
I came across the PyTorch Unique package, which implements the unique function as follows:
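    import torch

    def unique(x):
        unique, inverse = torch.unique(x, sorted=True, return_inverse=True)
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        # Reverse both so each value's earliest position is scattered last (flip needs dims)
        inverse, perm = inverse.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
        return unique, perm  # perm[i]: first index of unique[i], if the last duplicate write wins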
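The package's trick depends on which duplicate write `scatter_` keeps, and the PyTorch docs describe `scatter_` with repeated indices as nondeterministic. A minimal sketch of an order-independent alternative, assuming PyTorch 1.12 or later for `Tensor.scatter_reduce_`: this swaps the flip-and-scatter step for `reduce="amin"`, so each slot keeps the smallest position that maps to it. The helper name `unique_first_index` is hypothetical.

```python
import torch

def unique_first_index(x, dim=None):
    """Return unique values of x and the index of each one's first occurrence.

    Hypothetical helper; assumes Tensor.scatter_reduce_ (PyTorch 1.12+).
    """
    unique, inverse = torch.unique(x, sorted=True, return_inverse=True, dim=dim)
    # With dim=None, inverse has x's shape; flatten so positions index x.flatten()
    inverse = inverse.reshape(-1)
    n = inverse.size(0)
    positions = torch.arange(n, dtype=inverse.dtype, device=inverse.device)
    # Start each slot at n (larger than any position), then keep the smallest
    # position scattered into it; "amin" does not depend on write order
    first = torch.full((unique.size(0),), n, dtype=inverse.dtype, device=inverse.device)
    first.scatter_reduce_(0, inverse, positions, reduce="amin")
    return unique, first


x = torch.tensor([3, 1, 3, 2, 1])
u, first = unique_first_index(x)
# u     → tensor([1, 2, 3])
# first → tensor([1, 3, 0])

rows = torch.tensor([[5., 6.], [1., 2.], [5., 6.], [3., 4.], [1., 2.]])
u_rows, first_rows = unique_first_index(rows, dim=0)
# first_rows → tensor([1, 3, 0])
```

Initializing with `n` and keeping the default `include_self=True` is safe here because every unique value occurs at least once, so every slot receives a smaller position.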