Calculating number of unique values per row

Hi, I’m trying to count the number of unique values in each row of a tensor.

For example, I might have the following matrix:

a = torch.tensor([[1, 2, 3, 4, 5, 5, 6],
                  [2, 2, 1, 4, 1, 3, 4],
                  [7, 6, 5, 4, 3, 2, 1]])

The result should be:
result = [6, 4, 7]

At present I’m doing it like this, but it is slow when dim 0 is large because it loops over the rows in Python. Is there a better way? torch.unique can’t be applied directly, because the output has to be a regular (rectangular) tensor, while each row may have a different number of unique values.

result = [torch.tensor(torch.unique(a[i], dim=-1).size(0)) for i in range(a.size(0))]
result = torch.stack(result, dim=0)
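One loop-free alternative (a sketch, not something proposed in the thread): sort each row, count the positions where neighbouring sorted values differ, and add one. Every row keeps its full length, so the ragged-output problem with torch.unique never comes up. The helper name unique_per_row is hypothetical, and the sketch assumes integer data and at least one column per row.

```python
import torch

a = torch.tensor([[1, 2, 3, 4, 5, 5, 6],
                  [2, 2, 1, 4, 1, 3, 4],
                  [7, 6, 5, 4, 3, 2, 1]])

def unique_per_row(x):
    # Hypothetical helper; assumes x is 2-D with at least one column.
    # After sorting, equal values sit next to each other in each row.
    sorted_x, _ = torch.sort(x, dim=1)
    # True wherever a new value starts (compare each element to its left neighbour).
    changes = sorted_x[:, 1:] != sorted_x[:, :-1]
    # Number of value changes + 1 = number of distinct values in the row.
    return changes.sum(dim=1) + 1

result = unique_per_row(a)
print(result)  # tensor([6, 4, 7])
```

The cost is one batched sort over the whole tensor instead of a Python-level loop, so it should scale much better with the number of rows.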

Hi, try this…

import torch

a = torch.tensor([[1, 2, 3, 4, 5, 5, 6],
                  [2, 2, 1, 4, 1, 3, 4],
                  [7, 6, 5, 4, 3, 2, 1]])

x = torch.zeros(a.shape[0], dtype=torch.long)  # one integer count per row

for i in range(len(a)):
    x[i] = len(a[i].unique())

print(x)

Or without preallocating the x tensor:

import torch

a = torch.tensor([[1, 2, 3, 4, 5, 5, 6],
                  [2, 2, 1, 4, 1, 3, 4],
                  [7, 6, 5, 4, 3, 2, 1]])

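# Build the counts directly instead of overwriting `a` in place.
# (Writing each count back into its row and then calling
# torch.flatten(a).unique() would sort the counts and merge rows
# that happen to have the same count, e.g. giving tensor([4, 6, 7]).)
unique = torch.tensor([len(row.unique()) for row in a])
print(unique)  # tensor([6, 4, 7])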
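Another sketch, one that does use torch.unique as the question originally hoped: shift each row's values into its own non-overlapping integer range, call torch.unique once over the whole tensor, then count how many surviving keys fall into each row's range with torch.bincount. This assumes an integer tensor with at least one column, and that n_rows * (max - min + 1) fits in int64; the function name unique_per_row_offset is hypothetical.

```python
import torch

a = torch.tensor([[1, 2, 3, 4, 5, 5, 6],
                  [2, 2, 1, 4, 1, 3, 4],
                  [7, 6, 5, 4, 3, 2, 1]])

def unique_per_row_offset(x):
    # Hypothetical helper; assumes an integer 2-D tensor with >= 1 column.
    n_rows = x.size(0)
    shifted = x - x.min()                 # all values now in [0, width)
    width = int(shifted.max()) + 1
    # Give every row its own block of `width` keys so rows cannot collide.
    row_ids = torch.arange(n_rows, device=x.device).unsqueeze(1)
    keys = row_ids * width + shifted
    # One global unique call: duplicates within a row collapse to one key.
    uniq_keys = torch.unique(keys)
    # Map each surviving key back to its row and count keys per row.
    return torch.bincount(uniq_keys // width, minlength=n_rows)

print(unique_per_row_offset(a))  # tensor([6, 4, 7])
```

Unlike a one-hot or scatter approach, this never allocates an n_rows x width matrix; only the flattened keys are materialised.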