# Best way to constrain affine transformation to rotation

Hello everyone,

I know about Spatial Transformer Networks (https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html), but I need to restrict the transformation to rotation only.

Any suggestions on how to do this in a differentiable way?

If I predict the rotation angle and then build the rotation matrix from it, I can't backprop anymore (because of an in-place operation).

I will be very grateful if you can help me.
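For reference, here is a minimal sketch of building a rotation-only affine matrix from a predicted angle while keeping it differentiable. It assumes one angle per sample, in radians, with shape `(B, 1)` or `(B,)`; `rotation_matrix` is a hypothetical helper name. The key idea is to assemble the matrix with `torch.stack`, which autograd records, rather than `torch.tensor([...])`, which copies values into a new tensor that is detached from the graph.

```python
import torch


def rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
    """Build a batch of 2x3 rotation-only affine matrices (hypothetical helper).

    angle: shape (B,) or (B, 1), in radians.
    Returns a (B, 2, 3) tensor that stays attached to the autograd graph.
    """
    angle = angle.reshape(-1)  # (B,)
    cos, sin = torch.cos(angle), torch.sin(angle)
    zero = torch.zeros_like(angle)
    # torch.stack is a differentiable op, so gradients flow back to `angle`.
    row0 = torch.stack([cos, -sin, zero], dim=1)  # (B, 3)
    row1 = torch.stack([sin, cos, zero], dim=1)   # (B, 3)
    return torch.stack([row0, row1], dim=1)       # (B, 2, 3)


angle = torch.tensor([[0.3], [1.2]], requires_grad=True)  # e.g. a network output
mat = rotation_matrix(angle)
print(mat.shape)                # torch.Size([2, 2, 3])
print(mat.grad_fn is not None)  # True
mat.sum().backward()            # sum of entries is 2*cos(angle)
print(angle.grad)               # equals -2 * sin(angle)
```

Since only the angle is a free parameter, the resulting transform is a pure rotation by construction.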

Could you post the code you are using to create the rotation matrix (the part that triggers the in-place operation) and the error message?

Hi ptrblck,

Thank you very much in advance for your help.

My current code for the network is this:

```python
import torch.nn as nn
import torch
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # The input is a batch x 2 x 224 x 224 size tensor
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1, bias=False),
            nn.ELU(),
            nn.AvgPool2d(2, 2, 0),

            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128, momentum=0.9, eps=1e-5),
            nn.ELU(),

            nn.Conv2d(128, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128, momentum=0.9, eps=1e-5),
            nn.ELU(),
            nn.AvgPool2d(2, 2, 0),

            nn.Conv2d(128, 1, 3, 1, 1)
        )  # batch x 1 x 56 x 56

        self.fc_theta = nn.Sequential(
            nn.Linear(56 * 56, 32),
            nn.ReLU(True),
            nn.Linear(32, 1)
        )

    def rotation(self, image, theta):
        # rotation matrix
        rot_matrix = torch.tensor([
            torch.cos(theta), -torch.sin(theta), 0.,
            torch.sin(theta), torch.cos(theta), 0.
        ], dtype=torch.float).view(1, 2, 3)

        # Apply the rotation
        grid = F.affine_grid(rot_matrix, image.size(), align_corners=False)
        x = F.grid_sample(image, grid, align_corners=False)

        return x

    def forward(self, x, y):
        batch_size = x.size(0)
        xy = torch.cat((x, y), dim=1)

        out = self.encoder(xy)
        theta = self.fc_theta(out.view(batch_size, -1))

        x_rot = self.rotation(x, theta)
        return x_rot
```
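A possible drop-in replacement for the `rotation` method is sketched below as a standalone function. It assumes `theta` holds one angle in radians per sample (shape `(B, 1)`, as produced by `fc_theta`); the name `rotate` is hypothetical. It differs from the posted version in two ways: the matrix is assembled with `torch.stack` so it stays in the autograd graph, and it produces one `2 x 3` matrix per sample, since `F.affine_grid` expects the first dimension of the matrix to match `image.size(0)` (the posted `.view(1, 2, 3)` yields a single matrix).

```python
import torch
import torch.nn.functional as F


def rotate(image: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    """Rotate each image in a batch by its own angle, differentiably (sketch).

    image: (B, C, H, W); angle: (B, 1) or (B,), in radians.
    """
    angle = angle.reshape(-1)
    cos, sin = torch.cos(angle), torch.sin(angle)
    zero = torch.zeros_like(angle)
    # (B, 2, 3) rotation-only affine matrices, built with differentiable ops.
    rot = torch.stack([
        torch.stack([cos, -sin, zero], dim=1),
        torch.stack([sin, cos, zero], dim=1),
    ], dim=1)
    # One matrix per sample, so the batch sizes of `rot` and `image` agree.
    grid = F.affine_grid(rot, image.size(), align_corners=False)
    return F.grid_sample(image, grid, align_corners=False)


image = torch.rand(4, 1, 56, 56)
angle = torch.zeros(4, 1, requires_grad=True)
out = rotate(image, angle)
out.mean().backward()
print(out.shape)               # torch.Size([4, 1, 56, 56])
print(angle.grad is not None)  # True
```

With all angles set to zero, the output should match the input up to floating-point error, which is a quick sanity check that the matrix layout is right.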

but when I try to backprop the loss, I get:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
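The error can be reproduced with a single angle, which suggests the cause is not an in-place operation: `torch.tensor([...])` reads the *values* of `cos`/`sin` and builds a brand-new tensor with no link back to the angle. The sketch below (a scalar `angle` chosen only for illustration) contrasts that with `torch.stack`.

```python
import torch

angle = torch.tensor(0.5, requires_grad=True)

# torch.tensor() copies the values into a new tensor, cutting the graph.
detached = torch.tensor([torch.cos(angle), -torch.sin(angle)])
print(detached.requires_grad, detached.grad_fn)  # False None

message = ""
try:
    detached.sum().backward()
except RuntimeError as err:
    message = str(err)
print(message)  # element 0 of tensors does not require grad and does not have a grad_fn

# torch.stack records the operation, so backward reaches `angle`.
attached = torch.stack([torch.cos(angle), -torch.sin(angle)])
attached.sum().backward()
print(angle.grad)  # equals -sin(0.5) - cos(0.5)
```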

I suppose this is because of how `rot_matrix` is built, but if I try:

rot_matrix = torch.tensor([
torch.cos(theta), -torch.sin(theta), 0.,
torch.sin(theta), torch.cos(theta), 0.