I am working on an architecture which requires applying a rigid transformation to a non-square image.
To this end, I am using a spatial transformer module. However, applying a (rigid) rotation to a non-square image inevitably produces distortion, as can be seen in this image:
Is it possible to avoid this issue without explicitly padding the input to make it square and then cropping the result? Are there any parameters of, e.g., grid_sample that I am missing and that could help get this job done?
Example of a correct rotation:
Here is a self-contained example of the problem I am facing. The rotated output of the network ends up distorted, whereas I would prefer it to be transformed rigidly, even if it means some clipping occurs:
from math import sin, cos, pi
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
# Enable this if inside notebook
#%matplotlib inline
DEG_TO_RAD = pi / 180.0  # `pi` is imported directly from math above
angle_deg = 45
angle_rad = angle_deg * DEG_TO_RAD
image_size = (1, 1, 480, 640)
rotation = torch.tensor([
[ cos(angle_rad), sin(angle_rad), 0],
[-sin(angle_rad), cos(angle_rad), 0],
]).unsqueeze(0).cuda()
image = torch.zeros(image_size).float().cuda()
image[:, :, :, 280:340] = 255
# Uncomment this to get the right result
# image_pad = torch.zeros((1, 1, 640, 640)).float().cuda()
# image_pad[:, :, 80:-80, :] = image
image_pad=image
grid = F.affine_grid(rotation, size=image_pad.size())
rotated = F.grid_sample(image_pad, grid)
plt.figure()
plt.imshow(image[0, 0].data.cpu().numpy())
plt.title("Original")
plt.figure()
# Crop the padding rows again (yields 480 rows only when the padded 640x640 input is used)
rotated = rotated[:, :, 80:-80, :]
plt.imshow(rotated[0, 0].data.cpu().numpy())
plt.title("Rotated by {} degrees".format(angle_deg))
# Needed to display the figures when running as a script
plt.show()
Bottom line: can you rotate a non-square image without distorting it using the existing STN functionality, without having to pad the image with zeros or implement a custom version of grid_sample?
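One direction that might avoid the padding, sketched here under the assumption that the distortion comes from affine_grid working in normalized [-1, 1] coordinates on both axes regardless of aspect ratio: rescale the off-diagonal terms of the rotation matrix by the height/width ratio, so that the rotation is rigid in pixel space. The helper name rigid_rotation_theta is hypothetical, the example runs on the CPU, and it assumes align_corners=False (which requires PyTorch 1.3 or later).

```python
from math import cos, sin, pi

import torch
import torch.nn.functional as F


def rigid_rotation_theta(angle_rad, height, width):
    """Build a 1x2x3 affine matrix for F.affine_grid (hypothetical helper).

    The plain rotation matrix is conjugated with the aspect-ratio scaling:
    normalized coordinates are stretched to pixel proportions, rotated,
    and squeezed back. Assumes align_corners=False.
    """
    c, s = cos(angle_rad), sin(angle_rad)
    return torch.tensor([
        [c, s * height / width, 0.0],
        [-s * width / height, c, 0.0],
    ], dtype=torch.float32).unsqueeze(0)


# Same test image as in the question: a vertical stripe, 60 pixels wide
image = torch.zeros(1, 1, 480, 640)
image[:, :, :, 280:340] = 255.0

theta = rigid_rotation_theta(45 * pi / 180.0, 480, 640)
grid = F.affine_grid(theta, size=image.size(), align_corners=False)
rotated = F.grid_sample(image, grid, align_corners=False)
```

The output keeps the 480x640 shape of the input, and whatever rotates out of the frame is clipped to zeros by the default padding mode of grid_sample. With align_corners=True, the ratio would presumably need to be (height - 1) / (width - 1) instead.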