Differentiable and learnable rotations with grid_sample

I want to build a model containing a sub-network that learns to estimate a rotation angle for each individual data point.

However, with my current implementation, the gradients of the angle-estimation network are None.

Based on a suggestion in the thread "Differentiable affine transforms with grid_sample":

"…or use torch.cat or torch.stack to create theta in the forward method from the parameters."
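The suggestion works because torch.tensor(...) copies the values it receives into a brand-new leaf tensor, so autograd has no record of how they were computed, whereas torch.stack and torch.cat are differentiable operations that keep the graph intact. Below is a minimal sketch of the difference, using a single hypothetical parameter w as a stand-in for the angle network:

```python
import torch

# Hypothetical stand-in for the angle network: one learnable scalar.
w = torch.nn.Parameter(torch.tensor(0.5))
t = 2 * w  # an "angle" computed from the parameter

# Copying values into torch.tensor(...) creates a new leaf tensor.
# requires_grad=True makes theta_bad trainable by itself, but it is
# disconnected from w, so no gradient can reach w through it.
theta_bad = torch.tensor(
    [[torch.cos(t).item(), -torch.sin(t).item(), 0.0],
     [torch.sin(t).item(), torch.cos(t).item(), 0.0]],
    requires_grad=True,
)
theta_bad.sum().backward()
print(theta_bad.grad_fn, w.grad)  # None None

# Building the matrix with torch.stack keeps every entry in the graph.
zero = torch.zeros_like(t)
theta_good = torch.stack([
    torch.stack([torch.cos(t), -torch.sin(t), zero]),
    torch.stack([torch.sin(t), torch.cos(t), zero]),
])
theta_good.sum().backward()
print(theta_good.grad_fn is not None, w.grad is not None)  # True True
```

Note that passing requires_grad=True to torch.tensor does not reconnect anything; it only makes the copy a trainable leaf of its own.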

I tried using torch.stack() and torch.cat() on the list of rotation matrices, but the gradients are still None.
I display the gradients after the backward pass with this command:

print([(param.grad,name) for name, param in model.named_parameters()] )

This is the (truncated) output:

... ,(None, 'angle.0.weight'), (None, 'angle.0.bias'), (None, 'angle.2.weight'), (None, 'angle.2.bias')]
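Printing every gradient tensor is hard to scan. A small helper that lists only the parameters whose .grad is still None after loss.backward() makes this kind of problem easier to spot. This is a sketch: the function name missing_grads and the toy model are illustrative, not part of the original code.

```python
import torch
import torch.nn as nn


def missing_grads(model: nn.Module) -> list[str]:
    """Return names of trainable parameters that received no gradient."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]


# Toy model with the same naming pattern as self.angle (indices 0 and 2).
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
x = torch.randn(8, 4)

# Case 1: the output is cut off from the graph, similar to what
# re-wrapping values in torch.tensor(...) does.
detached = model(x).detach().requires_grad_()
detached.sum().backward()
print(missing_grads(model))  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Case 2: gradients flow normally through the model.
model.zero_grad()
model(x).sum().backward()
print(missing_grads(model))  # []
```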

This is the code I’m trying to adapt for my purpose (the original author is Ghassen Hamrouni, GHamrouni on GitHub). The issue occurs in the stn() method of the Net class.

# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion()   # interactive mode

Loading some data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
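train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),  # PIL image -> float tensor, required before Normalize
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)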
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),  # PIL image -> float tensor, required before Normalize
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)

stn() is the method containing the operation that produces the None gradients:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

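        self.angle = nn.Sequential(
            nn.Linear(28*28, 5),
            nn.ReLU(True),  # assumed: a parameter-free layer at index 1 is implied by the names angle.0 / angle.2
            nn.Linear(5, 1)
        )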

    # Spatial transformer network forward function
    def stn(self, x):

        angles = torch.arctan(self.angle(x.squeeze().reshape( (x.shape[0],28*28) )))*2
        theta = torch.stack([torch.tensor([[[torch.cos(t), -torch.sin(t), 0.0], [torch.sin(t), torch.cos(t), 0.0]] for t in angles], requires_grad = True )]).squeeze()
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid, mode = "bilinear")

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net().to(device)

Training code

optimizer = optim.SGD(model.parameters(), lr=0.01)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        # inspect the gradients after the backward pass
        print([(param.grad, name) for name, param in model.named_parameters()])
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

for epoch in range(1, 20 + 1):
    train(epoch)

I’d appreciate any suggestions on how to solve this issue. Thanks in advance!

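I managed to solve the issue by building theta with nested torch.stack calls instead of torch.tensor(...), which had copied the values into a new leaf tensor cut off from the autograd graph. This is how I changed the theta matrix: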

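# t has shape (1,), so each row stack is (3, 1); squeeze(-1) gives (B, 2, 3) even when B == 1.
# zeros_like keeps the zeros on the same device and dtype as the angles (torch.zeros(1) would sit on the CPU).
theta = torch.stack([
    torch.stack([
        torch.stack([torch.cos(t), -torch.sin(t), torch.zeros_like(t)]),
        torch.stack([torch.sin(t), torch.cos(t), torch.zeros_like(t)]),
    ])
    for t in angles
]).squeeze(-1)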
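The per-sample loop works, but it builds many tiny tensors. A batched alternative, sketched below under the assumption that angles has shape (B, 1) as produced by self.angle, builds theta for the whole batch with torch.cat and torch.stack, so it stays on the angles' device and inside the autograd graph. The helper name rotation_theta and the toy sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rotation_theta(angles: torch.Tensor) -> torch.Tensor:
    """Map angles of shape (B, 1) to 2x3 rotation matrices of shape (B, 2, 3)."""
    cos, sin = torch.cos(angles), torch.sin(angles)
    zeros = torch.zeros_like(angles)             # same device and dtype as angles
    row0 = torch.cat([cos, -sin, zeros], dim=1)  # (B, 3)
    row1 = torch.cat([sin, cos, zeros], dim=1)   # (B, 3)
    return torch.stack([row0, row1], dim=1)      # (B, 2, 3)


# Toy angle network shaped like self.angle (28*28 inputs -> 1 angle).
angle_net = nn.Sequential(nn.Linear(28 * 28, 5), nn.ReLU(), nn.Linear(5, 1))
x = torch.randn(4, 1, 28, 28)

# 2 * arctan maps the raw output into (-pi, pi).
angles = 2 * torch.arctan(angle_net(x.view(x.size(0), -1)))  # (B, 1)
theta = rotation_theta(angles)
grid = F.affine_grid(theta, x.size(), align_corners=False)
out = F.grid_sample(x, grid, mode="bilinear", align_corners=False)

out.sum().backward()
print(theta.shape)  # torch.Size([4, 2, 3])
print(all(p.grad is not None for p in angle_net.parameters()))  # True
```

Inside stn(), the same idea would replace the loop with a single call such as theta = rotation_theta(angles).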