Trainable sorting order?

Is it possible to have a trainable sorting layer in PyTorch?

The idea is that given 8 vectors, I want the layer to learn the best order before applying torch.cat(). See the minimal example I created below.

import torch
import torch.nn as nn

# Current implementation: the eight (1, 32) vectors are concatenated in a fixed order
x = [torch.rand(1, 32) for _ in range(8)]
x_out = torch.cat(x, dim=0)
print(x_out.shape)  # torch.Size([8, 32])

# Something like this?
sort_layer = nn.Parameter(torch.rand(8))
sort_idxs = torch.argsort(sort_layer)
x_out = torch.cat([x[i] for i in sort_idxs], dim=0)
print(x_out.shape)  # torch.Size([8, 32])

In my current implementation, the order is fixed. I tried using an nn.Parameter() together with torch.argsort(), but the parameter values are not updated during training. Any ideas on how this could be done?

Hi Brendan!

Not really (not in the context of using gradient-descent optimization, as is
typical with PyTorch).

The problem is that “order” isn’t (usefully) differentiable. Let’s say that you
start out with order 1, 2, 3 …, and then want to move to order 2, 1, 3 …
There is no continuous path between the two orderings, so as you vary the
sorting parameter, your loss function changes only in discontinuous jumps and
is piecewise constant in between, which is not (usefully) differentiable.
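
Here is a quick illustration of that piecewise-constant behavior (the scores
below are made up for the example): as one score is increased continuously,
the order returned by torch.argsort() stays fixed and then jumps discretely.

import torch

# increase the first score continuously; the order only changes in a jump
scores = torch.tensor([0.4, 0.6, 1.0])
for eps in [0.0, 0.1, 0.3]:
    bumped = scores + torch.tensor([eps, 0.0, 0.0])
    print(eps, torch.argsort(bumped).tolist())
# 0.0 [0, 1, 2]
# 0.1 [0, 1, 2]
# 0.3 [1, 0, 2]   <- discrete jump once the first score passes the second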

You can’t backpropagate through argsort() (because, as discussed above,
“order” is not differentiable). Note that argsort() returns a tensor with
dtype = torch.int64 and such tensors never carry requires_grad = True
(because integers are not usefully differentiable).
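
You can see this directly with a minimal check (the names here are just for
illustration):

import torch

scores = torch.rand(8, requires_grad=True)
idxs = torch.argsort(scores)
print(idxs.dtype)          # torch.int64
print(idxs.requires_grad)  # False -- the autograd graph is cut here

# indexing with idxs gives no grad path back to scores: after backward(),
# gradients reach x but scores.grad is still None
x = torch.rand(8, 32, requires_grad=True)
loss = x[idxs].sum()
loss.backward()
print(scores.grad)         # None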

(Hypothetically, you could have trainable weights that somehow “blend”
together your various orderings, but in your example you would have
factorial (8) = 40320 possible orderings, which would be impractically
many.)
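
For what it’s worth, here is a minimal sketch (my own illustration, not
something from the question) of one way such a “blend” could be relaxed
without enumerating all 40320 orderings: learn an 8 x 8 logit matrix and take
a row-wise softmax, giving a “soft permutation” that mixes the vectors with
differentiable weights. Note the caveat: the rows are not constrained to be
an exact permutation, so the output is a weighted blend rather than a true
reordering.

import torch
import torch.nn as nn

class SoftPermute(nn.Module):
    # Differentiable "blend" of orderings: each output row is a convex
    # combination of the input rows, so gradients flow to the logits.
    def __init__(self, n: int, temperature: float = 0.1):
        super().__init__()
        self.logits = nn.Parameter(torch.randn(n, n))
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (n, n) soft-permutation matrix; a lower temperature pushes each
        # row closer to a hard one-hot selection
        p = torch.softmax(self.logits / self.temperature, dim=-1)
        return p @ x  # (n, d) -- each row is a weighted mix of x's rows

x = torch.cat([torch.rand(1, 32) for _ in range(8)], dim=0)  # (8, 32)
layer = SoftPermute(8)
x_out = layer(x)
print(x_out.shape)  # torch.Size([8, 32])

At evaluation time one could harden each row with an argmax to recover an
actual ordering, but that step is again non-differentiable, so it only makes
sense outside of training.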

Best.

K. Frank
