Differentiable Sorting and Indices

I have a decoder outputting a tensor (let’s call it A) of size [batch_size, num_points, 3], basically a point cloud in three-dimensional Euclidean space. I’d like to sort each of the generated point clouds along a specific axis (let’s say the z-axis) before feeding the sorted tensor (call it B, also of size [batch_size, num_points, 3]) to the next network layers.
The problem is that I have no idea how to sort A differentiably in one step with torch.sort(). As an alternative, I also tried reducing A to a new tensor of size [batch_size, num_points, 1] that keeps only the z-axis values, sorting that, and then using the returned indices to reorder A. Doing so, I noticed that while the values returned by torch.sort have requires_grad=True, the indices do not. How does this affect overall differentiability? Is there a better way to sort each of the generated point clouds along a single axis (one of x, y, z) without losing differentiability?
Thank you!


You can think of sort() as doing two operations:
indices = torch.argsort(x, dim)
y = torch.gather(x, dim, indices)  # for a 1-D x, x[indices] works too

argsort() defines a non-trainable permutation: the indices are integers, so no gradient flows through them.
gather() applies this permutation. It is a differentiable operation: the backward pass permutes the gradients back to their original positions.
Note that you can use the indices to gather from other tensors, which amounts to a sortByKey() operation. So what you did, sorting z and reordering the bigger tensor with the resulting indices, is valid.
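For the point-cloud case from the question, a minimal sketch of this sort-by-key idea could look like the following. The tensor sizes, the random input, and the position-dependent `weights` are assumptions made up for illustration; only `torch.argsort`, `torch.gather`, and the [batch_size, num_points, 3] layout come from the discussion above.

```python
import torch

torch.manual_seed(0)
batch_size, num_points = 2, 5  # illustrative sizes, not from the original post
A = torch.randn(batch_size, num_points, 3, requires_grad=True)

# Sort keys: the z coordinate of every point, shape [batch_size, num_points]
z = A[:, :, 2]
indices = torch.argsort(z, dim=1)  # integer tensor, carries no gradient

# Expand the indices to [batch_size, num_points, 3] so gather moves whole points
idx = indices.unsqueeze(-1).expand(-1, -1, 3)
B = torch.gather(A, 1, idx)  # B[b, k] = A[b, indices[b, k]]

# Weight each sorted position differently so the gradient shows the routing
weights = torch.arange(1, num_points + 1, dtype=A.dtype).view(1, -1, 1)
(B * weights).sum().backward()

# A.grad[b, indices[b, k]] now holds the weight of sorted position k
print(indices.requires_grad, A.grad is not None)
```

The gradient reaches A even though `indices` itself is not differentiable: each point simply receives the gradient of the sorted slot it was moved to.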

Perfect! So, since the backward pass of sort() just routes the gradients back through the same permutation defined by the indices, the two approaches (the atomic one-call sort and the two-step argsort plus gather) are equivalent. I should have checked the resulting gradients myself!
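One way to run that check, as a rough sketch: sort the same data once with torch.sort and once with argsort plus gather, then compare values and gradients. The shapes, the random input, and the weight vector `w` are assumptions chosen only to make the gradient depend on the sorted position.

```python
import torch

torch.manual_seed(0)
x1 = torch.randn(4, 6, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)  # independent copy of the same data

w = torch.arange(1.0, 7.0)  # position-dependent weights, one per sorted slot

# One step: torch.sort returns differentiable values and integer indices
values, sort_idx = torch.sort(x1, dim=1)
(values * w).sum().backward()

# Two steps: argsort builds the permutation, gather applies it
idx = torch.argsort(x2, dim=1)
y = torch.gather(x2, 1, idx)
(y * w).sum().backward()

same_values = torch.equal(values, y)
same_grads = torch.equal(x1.grad, x2.grad)
print(same_values, same_grads)
```

With distinct random values (no ties) both flags should come out True; with repeated values the two calls could in principle order ties differently, so the comparison is only meaningful for inputs without duplicates.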