Differentiable Permutation

Hello, I was wondering if there is a way to perform a differentiable reordering of items so that the model learns the optimal ordering. Specifically, I have an input tensor of shape [batch_size, seq_length, dim] from a ViT block, and I want to shuffle the items along the second dimension so that the output tensor has the same shape, just with the items in a different order.

This is what I have so far:

```python
import torch
import torch.nn as nn
from fast_soft_sort.pytorch_ops import soft_rank  # differentiable ranking taken from https://github.com/google-research/fast-soft-sort/tree/master

class Shuffle(nn.Module):
    def __init__(self, L):
        """
        L: Length of the sequence
        """
        super().__init__()
        # Register the scores as a real Parameter of shape [1, L]
        # (soft_rank expects a 2-D [batch, n] tensor).
        self.weights = nn.Parameter(torch.randn(1, L) * 100)

    def forward(self, x):
        # x: [batch_size, seq_length, dim]
        # soft_rank returns float ranks in [1, L]; round, then shift to 0-based indices.
        # NOTE: .long() yields an integer tensor, so the gradient to self.weights is cut here.
        indices = soft_rank(self.weights).squeeze(0).round().long() - 1
        indices = indices.view(1, -1, 1).expand_as(x)
        return torch.gather(x, dim=1, index=indices)
```

The idea was for the model to learn self.weights, rank them with a differentiable ranking algorithm, and then use the ranks as indices for gather. The problem is that casting to .long() detaches the gradients: an integer tensor cannot carry gradients, so self.weights never receives any.
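A small check that reproduces the problem, using a hypothetical `scores` tensor in place of `self.weights`. It shows that the integer result of `.long()` (or of `argsort`) has no gradient history, so a gather driven by those indices only passes gradients to the gathered values, never to the scores that produced the indices.

```python
import torch

torch.manual_seed(0)

# Hypothetical learnable scores standing in for self.weights.
scores = torch.randn(5, requires_grad=True)

# A floating-point op keeps the result in the autograd graph...
scaled = scores * 100
print(scaled.requires_grad)   # True

# ...but casting to an integer dtype drops it from the graph.
indices = scaled.long()
print(indices.requires_grad)  # False
print(indices.grad_fn)        # None

# argsort also returns integer indices: indexing with them gives gradients
# for the values x, but nothing flows back to `scores`.
x = torch.randn(5, 3, requires_grad=True)
out = x[scores.argsort()]
out.sum().backward()
print(scores.grad)            # None
print(x.grad is not None)     # True
```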

I came across this thread on selecting items rather than shuffling them, but I'm not sure whether it's possible to apply that to my use case.

All help is greatly appreciated, thanks!

In order to implement something like this, here is what I would do:

  1. Have some non-model way of determining what is optimal. You could do this by testing all possible permutations and scoring each one with a predetermined algorithm based on whatever criteria you choose, or by scoring them manually according to whatever you consider optimal. This gives you targets: each target is a class index from 0 to n-1, where n is the number of possible permutations (n = L! for a sequence of length L).
  2. Create a convolutional network whose input dimensions match your inputs. The model should output n logits, one per permutation class.
  3. Train the model on the inputs and targets.

This way, you can treat it as a plain classification problem, using cross-entropy loss between the output logits and the targets.
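A minimal sketch of the three steps, assuming a short sequence (L = 4, so only 4! = 24 permutation classes) and token inputs of shape [B, L, D]. `PermutationClassifier`, `hypothetical_score` and `make_targets` are illustrative names, and the scoring criterion is a made-up placeholder for whatever "optimal" means in your setting.

```python
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

L, D = 4, 16  # assumed: L must stay small, since there are L! classes
PERMS = torch.tensor(list(itertools.permutations(range(L))))  # [L!, L]; class k -> PERMS[k]
N_CLASSES = PERMS.shape[0]  # 24 for L = 4

class PermutationClassifier(nn.Module):
    """Small conv net mapping tokens [B, L, D] to N_CLASSES logits (step 2)."""
    def __init__(self, dim, seq_len, n_classes, hidden=32):
        super().__init__()
        self.conv = nn.Conv1d(dim, hidden, kernel_size=3, padding=1)
        self.head = nn.Linear(hidden * seq_len, n_classes)

    def forward(self, x):                             # x: [B, L, D]
        h = F.relu(self.conv(x.transpose(1, 2)))      # [B, hidden, L]
        return self.head(h.flatten(1))                # [B, n_classes] logits

def hypothetical_score(sample, perm):
    """Placeholder for step 1: here, prefer orders whose first feature ascends."""
    return -(sample[perm, 0].diff() < 0).sum().item()

def make_targets(x):
    """Label each sample with the index of its best-scoring permutation."""
    return torch.tensor([
        max(range(N_CLASSES), key=lambda k: hypothetical_score(sample, PERMS[k]))
        for sample in x
    ])

torch.manual_seed(0)
model = PermutationClassifier(D, L, N_CLASSES)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, L, D)
targets = make_targets(x)                  # step 1: class indices in [0, N_CLASSES - 1]
logits = model(x)                          # step 2: [8, N_CLASSES]
loss = F.cross_entropy(logits, targets)    # step 3: cross entropy on the logits
opt.zero_grad()
loss.backward()
opt.step()

# At inference, apply the predicted permutation to each sample's tokens.
pred_perm = PERMS[logits.argmax(dim=1)]                            # [8, L]
x_perm = torch.gather(x, 1, pred_perm.unsqueeze(-1).expand_as(x))  # [8, L, D]
```

Enumerating every permutation only stays practical for very short sequences, so for a full ViT token sequence you would likely restrict the candidate set to a fixed list of permutations.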

Aside from the issue of gradients being detached by .long(), the model you have proposed has no way of computing a loss. Also, to get any kind of classification loss, you need the model to output logits or probabilities.

Thanks for the reply. This was meant to be part of a ViT, structured as:
ViT: [Shuffle → Attn] × N

That's how I wanted to implement the model. Is there a solution that fits this use case?

I would encourage you to have a look at some reinforcement learning models.

At any rate, you want the “Shuffle model” to take the same input as the ViT and reduce it to n class logits, one per permutation. Convolutional layers are a good choice for this.

The predicted class determines which permutation is applied. Apply that permutation to the original image, pass the result into the ViT, and compute the ViT's loss. Then build the targets: set every class to 0 (assuming the outputs are logits), and set the class that was actually chosen to (1 - ViT_loss). Keep in mind that you need the per-sample loss, taken before the mean across the batch, and that you should use a detached copy of it so no gradient flows back into the ViT. Samples with a lower loss will then feed back a stronger target signal for the choice that was made.

Use an L1 loss (nn.L1Loss) between the ‘Shuffle model’ outputs and the targets.
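A sketch of the target construction and L1 loss described above, assuming the ViT is trained with cross entropy. The tensors `shuffle_logits` and `vit_logits` are random stand-ins for the two models' outputs, and `shuffle_targets` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def shuffle_targets(shuffle_logits, chosen, per_sample_vit_loss):
    """Zeros everywhere, except each sample's chosen class, which gets (1 - ViT_loss)."""
    reward = 1.0 - per_sample_vit_loss.detach()  # detached copy: no gradient into the ViT
    targets = torch.zeros_like(shuffle_logits)   # requires_grad=False by default
    targets[torch.arange(chosen.shape[0]), chosen] = reward
    return targets

# Stand-ins (assumed shapes): 4 samples, 6 permutation classes, 10 image classes.
torch.manual_seed(0)
shuffle_logits = torch.randn(4, 6, requires_grad=True)  # 'Shuffle model' output
chosen = shuffle_logits.argmax(dim=1)                   # permutation class applied per sample
vit_logits = torch.randn(4, 10, requires_grad=True)     # ViT output on the permuted images
labels = torch.randint(10, (4,))

# reduction="none" keeps one loss per sample, before any batch mean.
per_sample_vit_loss = F.cross_entropy(vit_logits, labels, reduction="none")
vit_loss = per_sample_vit_loss.mean()                   # usual ViT training loss

targets = shuffle_targets(shuffle_logits, chosen, per_sample_vit_loss)
shuffle_loss = nn.L1Loss()(shuffle_logits, targets)

vit_loss.backward()      # trains the ViT
shuffle_loss.backward()  # trains the 'Shuffle model' (separate graph)
```

Cross-entropy values can exceed 1, which makes (1 - ViT_loss) negative; that is one reason to consider normalizing the reward.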

Additionally, you’ll want an epsilon-greedy policy that initially makes random choices for the ‘Shuffle model’. This lets it explore all the permutations. Epsilon should decay over time; see a DQN tutorial for examples of this.
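A minimal epsilon-greedy sketch with an exponential decay schedule of the kind often used in DQN tutorials. The function names and the defaults (`eps_start=1.0`, `eps_end=0.05`, `decay_steps=1000`) are assumptions to tune; here the explore/exploit decision is made independently for each sample.

```python
import math
import torch

def epsilon_at(step, eps_start=1.0, eps_end=0.05, decay_steps=1000):
    """Exponentially decaying epsilon: eps_start at step 0, approaching eps_end."""
    return eps_end + (eps_start - eps_end) * math.exp(-step / decay_steps)

def choose_permutation(shuffle_logits, step, n_classes):
    """Per sample: random class with probability epsilon, otherwise the argmax class."""
    greedy = shuffle_logits.argmax(dim=1)
    random_choice = torch.randint(n_classes, greedy.shape)
    explore = torch.rand(greedy.shape) < epsilon_at(step)
    return torch.where(explore, random_choice, greedy)

# Usage: early in training almost every choice is random.
torch.manual_seed(0)
choices = choose_permutation(torch.randn(16, 24), step=0, n_classes=24)
```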

You could also experiment with normalizing the reward. Instead of (1 - ViT_loss), you could use (ViT_loss_max - ViT_loss), or normalize per batch: 1 - (ViT_loss - ViT_loss_min) / (ViT_loss_max - ViT_loss_min), with an if statement that skips the case ViT_loss_max == ViT_loss_min to avoid dividing by zero.
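Both reward variants as small helpers operating on the detached per-sample ViT loss. The names are hypothetical, and returning `None` when all losses are equal is one reading of "skip": the caller would then skip the 'Shuffle model' update for that batch.

```python
import torch

def reward_max_minus(per_sample_loss):
    """Reward = ViT_loss_max - ViT_loss (always >= 0)."""
    loss = per_sample_loss.detach()
    return loss.max() - loss

def normalized_reward(per_sample_loss):
    """Per-batch min-max normalization: 1 - (loss - min) / (max - min), in [0, 1]."""
    loss = per_sample_loss.detach()
    lo, hi = loss.min(), loss.max()
    if hi.item() == lo.item():  # all losses equal: no signal, avoid division by zero
        return None             # caller skips the update for this batch
    return 1.0 - (loss - lo) / (hi - lo)

losses = torch.tensor([1.0, 2.0, 3.0])
print(normalized_reward(losses))  # tensor([1.0000, 0.5000, 0.0000])
print(reward_max_minus(losses))   # tensor([2., 1., 0.])
```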