Affine_grid and grid_sample: why is my image not rotated in the right direction?

Hi,

In PyTorch 1.8, I want to break down the constituents of a Spatial Transformer, in particular see how affine_grid and grid_sample work. So I came up with the following piece of code where I apply a rotation of angle pi/8 to an image:

import matplotlib.pyplot as plt
from matplotlib.image import imread
import numpy as np
import torch
import torch.nn.functional as F

img = imread("food.jpg")
w, h, _ = img.shape
img = img / 255.
print("image shape: ", img.shape)
x = torch.from_numpy(img)
x = x.type(torch.FloatTensor)
# add a batch size N=1
x = torch.unsqueeze(x,0)
# x is of size NxHxWx3, we need Nx3xHxW 
x = x.permute(0,3,1,2)
# affine transformation, size Nx2x3
angle = np.pi/8
theta = [ [ np.cos(angle)*w/h, -np.sin(angle)    , 0. ],
          [ np.sin(angle),      np.cos(angle)*w/h, 0. ] ]
theta = torch.tensor(theta, dtype=torch.float32).unsqueeze(0)
# grid is of size NxHxWx2
grid = F.affine_grid(theta, x.size(), align_corners=False)
x = F.grid_sample(x, grid, align_corners=False)
# x is of size Nx3xHxW, we need HxWx3 
x = x.squeeze(0)
x = x.permute(1,2,0)

fig = plt.figure(figsize=(16,8), facecolor='white')
ax = fig.add_subplot(1,3,1)
ax.imshow(img)
ax.axis('off')
ax.set_title("Source (U)", fontsize=16)
ax = fig.add_subplot(1,3,2)
ax.imshow(x)
ax.axis('off')
ax.set_title("Target (V)", fontsize=16)
fig.set_tight_layout(True)
fig.savefig('foo.png')
ax = fig.add_subplot(1,3,3)
ax.set_title("Grid", fontsize=16)
grid = grid.squeeze(0)
grid = grid.permute(2,0,1)
# grid is now of size 2xHxW
grid = grid.reshape(2,-1).numpy()
ax.scatter(grid[0,::100],grid[1,::100], s=0.01) 
plt.show()

and I got a three-panel figure (not reproduced here): "Source (U)", "Target (V)" and "Grid".

However, if I understand correctly, the grid (on the right) is laid over the lattice of the input feature map (on the left), and the output is then obtained by bilinear interpolation at those locations. If so, I would expect the bowl to be rotated in the opposite direction (i.e. tilted downwards to the right). Why isn't that the case?

Thanks for the help.

Hello!

On your scatter plot of the grid you need to add:

ax.invert_yaxis()

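This is because (-1, -1) is the top-left corner and (1, 1) is the bottom-right corner, so the y axis of the grid points downwards, whereas a matplotlib scatter plot draws y increasing upwards by default.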
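A quick way to check this convention yourself is sketched below. It assumes an identity transform and an arbitrary 4x6 output size, both chosen only for illustration (they are not taken from the code in the question). The first row of the grid should come out with a negative y and the last row with a positive y.

```python
import torch
import torch.nn.functional as F

# Identity affine transform of shape N x 2 x 3 (illustrative choice)
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])

# Target size is N x C x H x W; the 4 x 6 size is arbitrary
grid = F.affine_grid(theta, (1, 1, 4, 6), align_corners=False)

# The last dimension of the grid holds (x, y)
top_left = grid[0, 0, 0]        # first row, first column of the output
bottom_right = grid[0, -1, -1]  # last row, last column of the output

print("top-left     (x, y):", top_left.tolist())
print("bottom-right (x, y):", bottom_right.tolist())
```

Since row index 0 is drawn at the top by imshow but carries the most negative y, a plain scatter plot of the grid appears flipped vertically unless the y axis is inverted.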

See the grid_sample documentation (torch.nn.functional.grid_sample, PyTorch 1.11.0 documentation):

grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. For example, values x = -1, y = -1 is the left-top pixel of input, and values x = 1, y = 1 is the right-bottom pixel of input.
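The quoted sentence can be checked on a toy input. This is only a sketch: the 2x2 image and the hand-written grid are made up for illustration, and it uses align_corners=True (unlike the code in the question) so that -1 and 1 land exactly on the centers of the corner pixels.

```python
import torch
import torch.nn.functional as F

# Toy 1 x 1 x 2 x 2 image with distinct values:
# top-left=0, top-right=1, bottom-left=2, bottom-right=3
img = torch.arange(4, dtype=torch.float32).reshape(1, 1, 2, 2)

# Hand-written grid of shape N x H_out x W_out x 2; last dimension is (x, y)
corners = torch.tensor([[[[-1.0, -1.0], [1.0, -1.0]],
                         [[-1.0,  1.0], [1.0,  1.0]]]])

# align_corners=True makes -1 and 1 refer to the centers of the corner pixels
sampled = F.grid_sample(img, corners, mode="bilinear", align_corners=True)

print(sampled[0, 0])
```

Sampling at (x, y) = (-1, -1) returns the top-left value and (1, 1) returns the bottom-right value, so y = -1 corresponds to the top row of the image.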