I have a PyTorch tensor of shape (1, 512, 16, 3).

The last two dimensions, (16, 3), represent 16 points, each with (x, y, z) coordinates.

Within each group, the 16 points may all be unique, or some may be repeated.

I want to convert the tensor into a list of 512 lists, where each inner list contains only the unique points of its group.

The resulting nested list would have dimensions (1, 512, p, 3): one outer list containing 512 lists of varying size p, where each entry is a list of 3 coordinates (x, y, z).

So far I have used two nested for loops to extract each (16, 3) group and call torch.unique on it to pick out the unique points, but the loops make my code slow.
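For reference, the nested-loop version looks roughly like this. It is a minimal sketch; the function name `unique_points_loop` is my own placeholder, not something from PyTorch.

```python
import torch

def unique_points_loop(x: torch.Tensor) -> list:
    """Hypothetical helper: x has shape (1, 512, 16, 3)."""
    result = []
    for batch in x:              # iterates over dim 0 (size 1): batch is (512, 16, 3)
        groups = []
        for pts in batch:        # iterates over dim 1 (size 512): pts is (16, 3)
            # torch.unique with dim=0 returns the distinct rows, sorted
            groups.append(torch.unique(pts, dim=0).tolist())
        result.append(groups)
    return result
```

Note that torch.unique with dim=0 returns the rows in sorted order, not in their original order.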

Please let me know if there is a PyTorch function that does this without the loops.
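One possible loop-free approach, sketched here as an assumption rather than a built-in PyTorch function, swaps torch.unique for a pairwise-equality mask: compare every point in a group with every other point, drop any point that equals an earlier one, then split the surviving points by group. It assumes exact float equality defines "the same point". The helper name `unique_points_per_group` is hypothetical.

```python
import torch

def unique_points_per_group(x: torch.Tensor) -> list:
    """Hypothetical helper: x has shape (1, N, K, 3), e.g. (1, 512, 16, 3)."""
    pts = x[0]                                   # (N, K, 3)
    K = pts.shape[1]
    # eq[n, i, j] is True when points i and j of group n are exactly equal
    # (NaN != NaN, so points containing NaN are never treated as duplicates)
    eq = (pts[:, :, None, :] == pts[:, None, :, :]).all(dim=-1)  # (N, K, K)
    idx = torch.arange(K, device=x.device)
    earlier = idx[None, :] < idx[:, None]        # earlier[i, j] is True when j < i
    # point i is a duplicate if it equals some earlier point j < i
    is_dup = (eq & earlier).any(dim=-1)          # (N, K)
    keep = ~is_dup
    counts = keep.sum(dim=1)                     # number of unique points per group
    kept = pts[keep]                             # (total_unique, 3), groups stay contiguous
    groups = torch.split(kept, counts.tolist())  # tuple of N tensors, shape (p_n, 3)
    # Converting to ragged Python lists still iterates in Python
    return [[g.tolist() for g in groups]]
```

The pairwise comparison costs K² checks per group, which is cheap for K = 16. Unlike torch.unique, it keeps points in their original order. If the ragged nested list is not strictly required, the tuple of tensors from torch.split can be used directly and the final .tolist() step skipped.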

Thanks in advance.