Hello everyone, I’m trying to create a new layer: “Permutation”.
I want it to learn a permutation and apply it on the input.
Unlike torch.permute, I want the permutation to permute pixels in each channel, and not the tensor dimensions.
The algorithm I’m applying:
- Use torch.argsort on the layer weights first row to get a permutation vector
- Create a Permutation matrix that have ones in the column specified by the permutation vector, and zeros everywhere else
- Multiply the input with the permutation matrix
When Running the following code, I get grad=None, even if all the intermediate tensors have grad_fn and their requires_grad=True.
What am I doing wrong?
Here Is The class I wrote:
class Permutation(nn.Module):
class Permutation(nn.Module):
def __init__(self, size):
super(Permutation, self).__init__()
self.size = size ** 2
self.weight = torch.nn.Parameter(torch.empty(self.size, self.size))
def forward(self, x):
weight = self.weight.clone()
permutation_vector = torch.argsort(weight[0]).view(-1, 1)
weight = weight.masked_fill(weight <= weight.max(), 0.0)
weight = torch.scatter(weight, index=permutation_vector, dim=1, value=1)
output = x @ weight
return output
I used self.weight.clone() to avoid changing leaf variable (self.weight) in-place,
According to the suggestion here:
The rest of the code:
The Permutation model that uses Permutaion layer:
class PermNet(nn.Module):
def __init__(self, input_features):
super().__init__()
self.input_features = input_features
self.layers = torch.nn.Sequential(
Permutation(input_features),
# nn.Tanh(),
)
def forward(self, x):
original_shape = x.shape
output = x.flatten(start_dim=1)
output = self.layers(output)
return output.reshape(original_shape)
Here is my main training and evaluation loop:
import torch
from torch import nn
from tqdm import tqdm
from conv import Unet
from strict_perm_net import PermNet
from matplotlib import pyplot as plt
from torchvision import transforms, datasets
from non_rigid_transformation import non_rigid_transform
def mnist_data(size):
compose = transforms.Compose(
[transforms.ToTensor(),
transforms.Resize((size, size)),
transforms.Normalize(.5, .5)
])
out_dir = r'C:\Users\ariel\Downloads\mnist'
return datasets.MNIST(root=out_dir, train=True, transform=compose, download=False)
@torch.no_grad()
def evaluate(model, data_loader):
eval_loss = 0
for inputs, _ in data_loader:
input_images = non_rigid_transform(inputs)
outputs = model(input_images)
iteration_eval_loss = criterion(outputs, inputs)
eval_loss += iteration_eval_loss.item()
eval_loss /= len(data_loader)
return eval_loss
dataset = mnist_data(size=28)
dataset_single = torch.utils.data.Subset(dataset, list(range(50)))
train_size = int(0.7 * len(dataset_single))
val_size = int(0.15 * len(dataset_single))
test_size = len(dataset_single) - train_size - val_size
train_dataset, val_dataset, test_dataset = \
torch.utils.data.random_split(dataset_single, [train_size, val_size, test_size])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)
num_batches = len(train_loader)
model = PermNet(28)
# model = Unet()
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 100
train_loss = []
validation_loss = []
for _ in tqdm(range(epochs)):
model.train()
for idx, (img, _) in enumerate(train_loader):
optimizer.zero_grad()
input_image = non_rigid_transform(img)
prediction = model.forward(input_image)
train_iter_loss = criterion(prediction, img)
train_loss.append(train_iter_loss.data)
train_iter_loss.backward()
optimizer.step()
model.eval()
val_epoch_loss = evaluate(model, val_loader)
validation_loss.append(val_epoch_loss)
x_train = list(range(epochs * num_batches))
x_val = [epoch * num_batches for epoch in range(epochs)]
plt.plot(x_train, train_loss, color='g', label='train_loss')
plt.plot(x_val, validation_loss, color='b', label='validation_loss')
plt.legend()
plt.show()
s = model.forward(input_image)
f, a = plt.subplots(3, 3)
for i in range(3):
a[0][i % 3].imshow(input_image[i].squeeze().detach())
a[1][i % 3].imshow(s[i].squeeze().detach())
a[2][i % 3].imshow(img[i].squeeze().detach())
plt.suptitle("train")
plt.show()
val_sample = next(iter(test_loader))[0]
aug = non_rigid_transform(val_sample)
pred = model.forward(aug)
f, a = plt.subplots(1, 3)
a[0].imshow(aug[0].squeeze().detach())
a[1].imshow(pred[0].squeeze().detach())
a[2].imshow(val_sample[0].squeeze().detach())
plt.suptitle("test")
plt.show()
print("average test L1", evaluate(model, test_loader))