Creating an nn.Module that uses the model parameters indirectly gets grad=None

Hello everyone, I’m trying to create a new layer: “Permutation”.
I want it to learn a permutation and apply it on the input.
Unlike torch.permute, I want the permutation to permute pixels in each channel, and not the tensor dimensions.
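To make the distinction concrete, here is a minimal sketch on a toy tensor. The shapes and the permutation vector are arbitrary illustrative choices. torch.permute reorders the dimensions, while indexing the flattened pixels with one index vector reorders pixel positions inside every channel and keeps the shape unchanged.

```python
import torch

# Toy input: batch of 1, 2 channels, 2x3 pixels, values 0..11 for readability
x = torch.arange(12.0).reshape(1, 2, 2, 3)

# torch.permute reorders the tensor dimensions: (N, C, H, W) -> (N, H, W, C)
dims_permuted = x.permute(0, 2, 3, 1)

# A pixel permutation keeps the shape and reorders the H*W positions inside
# every channel with one index vector (an arbitrary example permutation)
perm = torch.tensor([5, 0, 1, 2, 3, 4])
flat = x.flatten(start_dim=2)               # shape (1, 2, 6): pixels per channel
pixels_permuted = flat[:, :, perm].reshape(x.shape)

print(dims_permuted.shape)              # torch.Size([1, 2, 3, 2])
print(pixels_permuted[0, 0].tolist())   # [[5.0, 0.0, 1.0], [2.0, 3.0, 4.0]]
```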

The algorithm I’m applying:

  1. Use torch.argsort on the first row of the layer's weight to get a permutation vector
  2. Create a permutation matrix that has a one in each row, in the column specified by the permutation vector, and zeros everywhere else
  3. Multiply the input by the permutation matrix
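The three steps above can be sketched on a 4-pixel example. The weight values are made up, and the matrix is built with direct indexing rather than masked_fill/scatter. This produces the same kind of matrix: one 1 per row, in column perm[i]. Multiplying a row vector by it moves x[i] to position perm[i].

```python
import torch

# Hypothetical first row of the weight (values made up for illustration)
weight_row = torch.tensor([0.3, -1.2, 2.5, 0.0])

# Step 1: argsort gives an integer permutation vector
perm = torch.argsort(weight_row)

# Step 2: permutation matrix with a single 1 per row, in column perm[i]
P = torch.zeros(4, 4)
P[torch.arange(4), perm] = 1.0

# Step 3: multiply; for a row vector x, (x @ P)[perm[i]] == x[i]
x = torch.tensor([10.0, 20.0, 30.0, 40.0])
out = x @ P

print(perm.tolist())   # [1, 3, 0, 2]
print(out.tolist())    # [30.0, 10.0, 40.0, 20.0]
```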

When running the following code, I get grad=None, even though all the intermediate tensors have a grad_fn and requires_grad=True.
What am I doing wrong?

Here is the class I wrote:

import torch
from torch import nn

class Permutation(nn.Module):
    def __init__(self, size):
        super(Permutation, self).__init__()
        self.size = size ** 2
        self.weight = torch.nn.Parameter(torch.empty(self.size, self.size))

    def forward(self, x):
        weight = self.weight.clone()
        # step 1: permutation vector from the first row of the weight
        permutation_vector = torch.argsort(weight[0]).view(-1, 1)
        # step 2: zero the matrix, then put a 1 in column permutation_vector[i] of row i
        weight = weight.masked_fill(weight <= weight.max(), 0.0)
        weight = torch.scatter(weight, index=permutation_vector, dim=1, value=1)
        # step 3: apply the permutation matrix
        output = x @ weight
        return output

I used self.weight.clone() to avoid modifying the leaf variable (self.weight) in place,
following the suggestion here:
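For context, here is a minimal sketch of the error that clone() avoids. An in-place operation directly on a leaf tensor that requires grad is rejected. The same operation on a clone is allowed, and the clone stays in the computation graph. The tensor sizes and values are arbitrary.

```python
import torch

w = torch.nn.Parameter(torch.zeros(3))   # leaf tensor with requires_grad=True

in_place_failed = False
try:
    w.masked_fill_(w <= w.max(), 1.0)    # in-place op directly on the leaf
except RuntimeError as err:
    in_place_failed = True
    print("in-place on leaf:", err)

w2 = w.clone()                           # non-leaf copy that is still in the graph
w2.masked_fill_(w2 <= w2.max(), 1.0)     # in-place op on the clone is allowed
print(w2.grad_fn is not None)            # True
```

Note that masked_fill() (without the trailing underscore) and torch.scatter() are out-of-place and already return new tensors. So the clone() is not strictly required for the forward() above.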

The rest of the code:
The PermNet model that uses the Permutation layer:

class PermNet(nn.Module):
    def __init__(self, input_features):
        super(PermNet, self).__init__()
        self.input_features = input_features
        self.layers = torch.nn.Sequential(
            Permutation(input_features),  # assumed: the layer list was truncated in the post
            # nn.Tanh(),
        )

    def forward(self, x):
        original_shape = x.shape
        output = x.flatten(start_dim=1)
        output = self.layers(output)
        return output.reshape(original_shape)

Here is my main training and evaluation loop:

import torch
from torch import nn
from tqdm import tqdm
from conv import Unet
from strict_perm_net import PermNet
from matplotlib import pyplot as plt
from torchvision import transforms, datasets
from non_rigid_transformation import non_rigid_transform

def mnist_data(size):
    compose = transforms.Compose([
        transforms.ToTensor(),  # assumed: converts the PIL image so Normalize receives a tensor
        transforms.Resize((size, size)),
        transforms.Normalize(.5, .5),
    ])
    out_dir = r'C:\Users\ariel\Downloads\mnist'
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=False)

def evaluate(model, data_loader):
    eval_loss = 0
    with torch.no_grad():  # no computation graph is needed during evaluation
        for inputs, _ in data_loader:
            input_images = non_rigid_transform(inputs)
            outputs = model(input_images)
            iteration_eval_loss = criterion(outputs, inputs)
            eval_loss += iteration_eval_loss.item()

    eval_loss /= len(data_loader)
    return eval_loss

dataset = mnist_data(size=28)
dataset_single = torch.utils.data.Subset(dataset, list(range(50)))
train_size = int(0.7 * len(dataset_single))
val_size = int(0.15 * len(dataset_single))
test_size = len(dataset_single) - train_size - val_size
train_dataset, val_dataset, test_dataset = \
    torch.utils.data.random_split(dataset_single, [train_size, val_size, test_size])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)
num_batches = len(train_loader)

model = PermNet(28)
# model = Unet()
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 100
train_loss = []
validation_loss = []

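for _ in tqdm(range(epochs)):
    for idx, (img, _) in enumerate(train_loader):
        optimizer.zero_grad()
        input_image = non_rigid_transform(img)
        prediction = model(input_image)
        train_iter_loss = criterion(prediction, img)
        train_iter_loss.backward()  # populates .grad on the model parameters
        optimizer.step()
        train_loss.append(train_iter_loss.item())

    val_epoch_loss = evaluate(model, val_loader)
    validation_loss.append(val_epoch_loss)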

x_train = list(range(epochs * num_batches))
x_val = [epoch * num_batches for epoch in range(epochs)]
plt.plot(x_train, train_loss, color='g', label='train_loss')
plt.plot(x_val, validation_loss, color='b', label='validation_loss')
plt.legend()

s = model(input_image)

f, a = plt.subplots(3, 3)
for i in range(3):
    a[0][i % 3].imshow(input_image[i].squeeze().detach())
    a[1][i % 3].imshow(s[i].squeeze().detach())
    a[2][i % 3].imshow(img[i].squeeze().detach())


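val_sample = next(iter(test_loader))[0]
aug = non_rigid_transform(val_sample)
pred = model(aug)

f, a = plt.subplots(1, 3)
a[0].imshow(aug[0].squeeze().detach())         # assumed panels: input, prediction, target
a[1].imshow(pred[0].squeeze().detach())
a[2].imshow(val_sample[0].squeeze().detach())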

print("average test L1", evaluate(model, test_loader))

Hi Ariel!

The short story is that you will not be able to “learn” a permutation (at
least not with autograd and gradient-descent-based optimizers).

The problem is that a permutation is discrete – it’s not a continuous
entity with respect to which you can differentiate. (Consider encoding
your permutation with the indices of ones in a permutation matrix. Those
indices are integers and how do you differentiate with respect to an integer?)
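A minimal sketch of the point about integers. First, PyTorch refuses to track gradients for integer tensors such as permutation indices. Second, as an analogy, torch.round is chosen here as an example of a discrete, step-like operation. It is piecewise constant, and autograd reports a zero derivative for it.

```python
import torch

# 1) Integer tensors (such as permutation indices) cannot require gradients
idx = torch.tensor([2, 0, 1])            # dtype torch.int64
int_grad_failed = False
try:
    idx.requires_grad_(True)
except RuntimeError as err:
    int_grad_failed = True
    print(err)

# 2) A step-like operation is piecewise constant, so its derivative is zero
#    almost everywhere; torch.round's backward returns zeros
t = torch.tensor(1.3, requires_grad=True)
torch.round(t).backward()
print(t.grad)                            # tensor(0.)
```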

Commenting just on your Permutation class, I can’t reproduce your
result in that I do not get grad = None. However, I do get a grad that
is identically zero, and hence not useful.
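Why an identically-zero grad is not useful can be sketched with the same optimizer the training loop uses (Adam, lr=0.0001). With an all-zero gradient, Adam's moment estimates stay zero, so the update is zero and the weight does not move. The 4x4 size is arbitrary, and the zero grad is set by hand as a stand-in.

```python
import torch

torch.manual_seed(0)
w = torch.nn.Parameter(torch.randn(4, 4))
opt = torch.optim.Adam([w], lr=0.0001)
before = w.detach().clone()

w.grad = torch.zeros_like(w)     # stand-in for an identically-zero gradient
opt.step()

print(torch.equal(w.detach(), before))   # True: zero gradient, zero update
```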

Two comments:

First, argsort() returns integer indices and is hence not differentiable. You
will see that permutation_vector does not have a grad_fn.
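A minimal sketch of where the graph stops. A row selected from a weight that requires grad still has a grad_fn. The integer output of argsort() has no grad_fn and does not require grad. The 3x3 shape is arbitrary.

```python
import torch

weight = torch.randn(3, 3, requires_grad=True)
row = weight[0]                       # still part of the graph
perm = torch.argsort(row)             # integer indices

print(row.grad_fn is not None)        # True
print(perm.dtype)                     # torch.int64
print(perm.grad_fn)                   # None
print(perm.requires_grad)             # False
```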

Second, although masked_fill() and scatter() are potentially differentiable,
in your case you fill weight entirely with zeros and then scatter in some
ones, neither of which is (usefully) differentiable.
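A minimal sketch of this second point, mirroring the masked_fill()/scatter() pattern from the Permutation class on a 3x3 tensor. The scatter indices and the dummy loss are arbitrary. The result keeps a grad_fn, but every entry was overwritten by a constant, so the gradient that reaches the weight is all zeros.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, 3, requires_grad=True)

# masked_fill with a mask that is True everywhere (w <= w.max() always holds)
filled = w.masked_fill(w <= w.max(), 0.0)
# scatter a constant 1 into one column per row (arbitrary example indices)
idx = torch.tensor([[2], [0], [1]])
P = torch.scatter(filled, dim=1, index=idx, value=1.0)

loss = (torch.randn(3) @ P).sum()     # dummy loss
loss.backward()
print(P.grad_fn is not None)          # True: P is still in the graph
print(torch.count_nonzero(w.grad))    # tensor(0): yet the gradient is all zeros
```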

Consider this script with an annotated version of your Permutation class:

import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

class Permutation(torch.nn.Module):
    def __init__(self, size):
        super(Permutation, self).__init__()
        self.size = size ** 2
        self.weight = torch.nn.Parameter(torch.empty(self.size, self.size))   # uninitialized memory
        with torch.no_grad():                                                 # clean up garbage
            self.weight[self.weight.isnan()] = 666.0
            self.weight[self.weight.isinf()] = 999.0
    def forward(self, x):
        weight = self.weight.clone()
        print ('weight.grad_fn:', weight.grad_fn)                             # part of computation graph
        permutation_vector = torch.argsort(weight[0]).view(-1, 1)
        print ('permutation_vector.grad_fn:', permutation_vector.grad_fn)     # breaks computation graph
        weight = weight.masked_fill(weight <= weight.max(), 0.0)
        print ('weight.grad_fn:', weight.grad_fn)                             # still part of computation graph
        print ('torch.all (weight == 0.0):', torch.all (weight == 0.0))       # but identically zero
        weight = torch.scatter(weight, index=permutation_vector, dim=1, value=1)
        print ('weight.grad_fn:', weight.grad_fn)                             # still part of computation graph
        output = x @ weight
        print ('output.grad_fn:', output.grad_fn)                             # still part of computation graph
        return output

perm_mod = Permutation (3)
print ('perm_mod.weight.grad is None:', perm_mod.weight.grad is None)         # no grad yet

x = torch.randn (9)
output = perm_mod (x)
loss = ((output + torch.randn (9))**2).sum()                                  # some dummy loss
loss.backward()                                                               # populates perm_mod.weight.grad

print ('perm_mod.weight.grad is None:', perm_mod.weight.grad is None)         # has a grad
wgrad = perm_mod.weight.grad
print ('torch.all (wgrad == 0.0):', torch.all (wgrad == 0.0))                 # but grad is identically zero

Here is its output:

perm_mod.weight.grad is None: True
weight.grad_fn: <CloneBackward0 object at 0x000001A3FE62BE80>
permutation_vector.grad_fn: None
weight.grad_fn: <MaskedFillBackward0 object at 0x000001A3FE62BE80>
torch.all (weight == 0.0): tensor(True)
weight.grad_fn: <ScatterBackward1 object at 0x000001A3FE62BE80>
output.grad_fn: <SqueezeBackward3 object at 0x000001A3FE62BE80>
perm_mod.weight.grad is None: False
torch.all (wgrad == 0.0): tensor(True)

(As an aside, your use of torch.empty() is flawed in that you never initialize
the memory allocated by torch.empty(). That is, torch.empty() just plops
your tensor down on some location in memory that contains whatever it
happened to contain previously. The contents of such memory are not
reproducible and can contain perverse values and things like nans. You
should initialize weight explicitly after creating it.)
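One possible initialization, assuming normally distributed starting values are acceptable (the choice of initializer here is an assumption), is to overwrite the torch.empty() memory in place with torch.nn.init.normal_(). Alternatively, create the Parameter from torch.randn() directly.

```python
import torch

size = 3
weight = torch.nn.Parameter(torch.empty(size ** 2, size ** 2))  # uninitialized memory
torch.nn.init.normal_(weight, mean=0.0, std=1.0)                # overwrite it in place

print(torch.isfinite(weight).all().item())   # True
# An equivalent one-liner: torch.nn.Parameter(torch.randn(size ** 2, size ** 2))
```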

But again, all of this detail is moot because you can’t usefully train a discrete
object such as a permutation.


K. Frank