tf.unique_with_counts PyTorch equivalent?

Good morning,

I’m translating code from TensorFlow to PyTorch. TF has a function, tf.unique_with_counts, that is best explained with the example from its documentation (TensorFlow v2.12.0):

import tensorflow as tf

x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8], tf.float32)
print('\n' + 'x: ')
print(x)

unique_vals, unique_idxs, unique_val_counts = tf.unique_with_counts(x)

print('\n' + 'unique_vals: ')
print(unique_vals)

print('\n' + 'unique_idxs: ')
print(unique_idxs)

print('\n' + 'unique_val_counts: ')
print(unique_val_counts)

output:

x: 
tf.Tensor([1. 1. 2. 4. 4. 4. 7. 8. 8.], shape=(9,), dtype=float32)

unique_vals: 
tf.Tensor([1. 2. 4. 7. 8.], shape=(5,), dtype=float32)

unique_idxs: 
tf.Tensor([0 0 1 2 2 2 3 4 4], shape=(9,), dtype=int32)

unique_val_counts: 
tf.Tensor([2 1 3 1 2], shape=(5,), dtype=int32)
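To pin down what needs to be replicated, here is a minimal plain-Python sketch of the semantics (the helper name unique_with_counts_reference is hypothetical, not part of either library). Per the TF docs, the unique values come out in the order they first occur in x, idx maps each element of x to its position in that list, and counts holds how often each unique value appears.

```python
def unique_with_counts_reference(values):
    """Hypothetical reference: unique values in first-occurrence order, plus idx and counts."""
    unique_vals = []
    unique_idxs = []
    counts = []
    position = {}  # value -> index into unique_vals
    for v in values:
        if v not in position:
            # First time we see v: append it and start its count at zero.
            position[v] = len(unique_vals)
            unique_vals.append(v)
            counts.append(0)
        i = position[v]
        unique_idxs.append(i)
        counts[i] += 1
    return unique_vals, unique_idxs, counts

vals, idxs, counts = unique_with_counts_reference([1., 1., 2., 4., 4., 4., 7., 8., 8.])
print(vals)    # [1.0, 2.0, 4.0, 7.0, 8.0]
print(idxs)    # [0, 0, 1, 2, 2, 2, 3, 4, 4]
print(counts)  # [2, 1, 3, 1, 2]
```

Note that unique_vals[idxs[k]] == x[k] for every k, which is a handy check when comparing ports.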

What is the best way to replicate this function in PyTorch?

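Well, this was an easy one. Passing return_inverse=True and return_counts=True to torch.unique (both default to False) reproduces the TF output for this input. One caveat: with sorted=True the unique values come back in ascending order, while TF keeps them in order of first occurrence, so the two agree here only because x is already sorted. The inverse indices are also int64 in PyTorch rather than int32.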

import torch

x = torch.tensor([1, 1, 2, 4, 4, 4, 7, 8, 8], dtype=torch.float32)

unique_vals, unique_idxs, unique_val_counts = torch.unique(x, sorted=True, return_inverse=True, return_counts=True)

print('\n' + 'unique_vals: ')
print(unique_vals)

print('\n' + 'unique_idxs: ')
print(unique_idxs)

print('\n' + 'unique_val_counts: ')
print(unique_val_counts)

result:

unique_vals: 
tensor([1., 2., 4., 7., 8.])

unique_idxs: 
tensor([0, 0, 1, 2, 2, 2, 3, 4, 4])

unique_val_counts: 
tensor([2, 1, 3, 1, 2])
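For inputs that are not already sorted, a sketch like the following could recover TF's first-occurrence ordering on top of torch.unique. It assumes a 1-D tensor and PyTorch 1.12 or later (for Tensor.scatter_reduce with reduce="amin"); the helper name unique_with_counts_tf_order is hypothetical. The idea: find the first position of each sorted unique value, reorder the unique values and counts by that position, and remap the inverse indices accordingly.

```python
import torch

def unique_with_counts_tf_order(x: torch.Tensor):
    """Hypothetical helper: unique values in first-occurrence order, like tf.unique_with_counts."""
    assert x.dim() == 1, "sketch assumes a 1-D tensor, as tf.unique_with_counts does"
    # Sorted unique values, inverse mapping into them, and per-value counts.
    sorted_vals, inverse, counts = torch.unique(
        x, sorted=True, return_inverse=True, return_counts=True
    )
    # first_idx[k] = smallest position in x holding sorted unique value k.
    positions = torch.arange(x.numel(), device=x.device)
    first_idx = torch.full((sorted_vals.numel(),), x.numel(),
                           dtype=torch.long, device=x.device)
    first_idx = first_idx.scatter_reduce(0, inverse, positions, reduce="amin")
    # order lists the sorted unique values by first occurrence.
    order = torch.argsort(first_idx)
    # rank[k] = new position of sorted unique value k after reordering.
    rank = torch.empty_like(order)
    rank[order] = torch.arange(order.numel(), device=x.device)
    return sorted_vals[order], rank[inverse], counts[order]

x = torch.tensor([3, 1, 3, 2, 1])
vals, idxs, counts = unique_with_counts_tf_order(x)
print(vals)    # tensor([3, 1, 2])
print(idxs)    # tensor([0, 1, 0, 2, 1])
print(counts)  # tensor([2, 2, 1])
```

On the sorted example above this returns the same three tensors as the plain torch.unique call, so it only matters when the order of first appearance differs from ascending order.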