Hi, I’m looking for a way to count the number of unique values in each row of a tensor.
For example, given the following matrix:
```python
import torch

a = torch.tensor([[1, 2, 3, 4, 5, 5, 6],
                  [2, 2, 1, 4, 1, 3, 4],
                  [7, 6, 5, 4, 3, 2, 1]])
```
the expected result is:
```python
result = [6, 4, 7]
```
At present I’m doing it with the for-loop below, but it is slow when dim 0 is large. Is there a better way to do this? torch.unique can’t be applied directly here, because the output has to be a rectangular tensor, whereas each row may contain a different number of unique values.
```python
# Current approach: one torch.unique call per row, so Python loops over dim 0
result = [torch.tensor(torch.unique(a[i], dim=-1).size(0)) for i in range(a.size(0))]
result = torch.stack(result, dim=0)
```
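A possible loop-free alternative, sketched here as an assumption rather than taken from the original post: sort each row so that equal values become adjacent, then count the positions where a value differs from its left neighbour and add one for the first element. A second sketch shifts each row into its own disjoint value range so that a single torch.unique call on the flattened tensor keeps the rows apart, then counts the survivors per row with torch.bincount. Both function names are hypothetical. Both assume a non-empty 2-D tensor without NaNs, and the offset version additionally assumes integer data whose shifted range times the number of rows fits in int64.

```python
import torch


def count_unique_per_row(x: torch.Tensor) -> torch.Tensor:
    """Sort-based sketch (hypothetical helper): distinct values per row of a 2-D tensor.

    Assumes x has at least one column and contains no NaNs.
    """
    # Sorting each row places equal values next to each other.
    s, _ = torch.sort(x, dim=1)
    # True wherever an element differs from its left neighbour, i.e. where a new
    # distinct value starts; the first element of each row always starts one (+1).
    starts = s[:, 1:] != s[:, :-1]
    return starts.sum(dim=1) + 1


def count_unique_per_row_offset(x: torch.Tensor) -> torch.Tensor:
    """Offset-based sketch (hypothetical helper) for integer tensors.

    Each row is shifted into its own disjoint value range so that one call to
    torch.unique on the flattened tensor keeps values from different rows apart.
    """
    n_rows = x.size(0)
    shifted = x - x.min()                  # all values are now >= 0
    span = int(shifted.max()) + 1          # width of one row's value range
    offsets = torch.arange(n_rows, device=x.device).unsqueeze(1) * span
    uniq = torch.unique(shifted + offsets)  # flattened, sorted distinct values
    row_ids = uniq // span                  # recover the row each value came from
    return torch.bincount(row_ids, minlength=n_rows)


a = torch.tensor([[1, 2, 3, 4, 5, 5, 6],
                  [2, 2, 1, 4, 1, 3, 4],
                  [7, 6, 5, 4, 3, 2, 1]])

print(count_unique_per_row(a))         # tensor([6, 4, 7])
print(count_unique_per_row_offset(a))  # tensor([6, 4, 7])
```

Both versions replace the per-row Python loop with a fixed number of batched tensor operations. The sort-based one is the simpler of the two and also applies to floating-point rows without NaNs; the offset one only makes sense for integers, and its int(...) calls force a device-to-host synchronization when the tensor lives on a GPU.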