I want to create a model that contains a network that learns to estimate rotation angles for individual data points.
However, with my current implementation, the Gradients of the angle embedding network become None.
Based on a suggestion here: Differentiable affine transforms with grid_sample
or use
torch.cat
ortorch.stack
to createtheta
in theforward
method from the parameters.
I tried using .stack() and .cat() on the list of rotation matrices; however my gradients still become None.
I display the gradients after the backward computation with this command:
print([(param.grad,name) for name, param in model.named_parameters()] )
and this is the output
... ,(None, 'angle.0.weight'), (None, 'angle.0.bias'), (None, 'angle.2.weight'), (None, 'angle.2.bias')]
This is the code that I’m trying to adapt for my purpose ( The original author of the code is Ghassen HAMROUNI, GHamrouni (Ghassen Hamrouni) · GitHub). The issue occurs in the Net Class in the method called stn().
# License: BSD
# Author: Ghassen Hamrouni
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
plt.ion() # interactive mode
Loading some data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Training dataset
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='.', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='.', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=64, shuffle=True, num_workers=4)
stn() is the method where I’m performing the operation that results in the None gradients
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
self.angle = nn.Sequential(
nn.Linear(28*28, 5),
nn.ReLU(True),
nn.Linear(5, 1)
)
# Spatial transformer network forward function
def stn(self, x):
angles = torch.arctan(self.angle(x.squeeze().reshape( (x.shape[0],28*28) )))*2
theta = torch.stack([torch.tensor([[[torch.cos(t), -torch.sin(t), 0.0], [torch.sin(t), torch.cos(t), 0.0]] for t in angles], requires_grad = True )]).squeeze()
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid, mode = "bilinear")
return x
def forward(self, x):
# transform the input
x = self.stn(x)
# Perform the usual forward pass
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net().to(device)
Training code
optimizer = optim.SGD(model.parameters(), lr=0.01)
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
print([(param.grad,name) for name, param in model.named_parameters()] )
optimizer.step()
if batch_idx % 500 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
for epoch in range(1, 20 + 1):
train(epoch)
I’m happy for any suggestions on how to solve this issue. Thank you in advance!
SOLUTION:
I managed to solve the issue, this is how I changed the theta matrix:
theta = torch.stack( [ torch.stack([torch.stack([torch.cos(t).unsqueeze(dim=0), -torch.sin(t).unsqueeze(dim=0), torch.zeros(1)]), torch.stack([torch.sin(t).unsqueeze(dim=0), torch.cos(t).unsqueeze(dim=0), torch.zeros(1)])]) for t in angles] ).squeeze()