With the following code, you should get the same output but different gradient between pytorch STN function and my own.
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='.', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,),(0.3081,))
])), batch_size=64, shuffle=True, num_workers=4)
# Get a batch of training data
data = next(iter(test_loader))[0].to(device)
input_tensor = data.cpu()
theta_base = Variable(torch.Tensor([0.8, 0, 0.0, 0.0, 0.8, 0]).view(1,2,3),requires_grad=True)
theta = theta_base.expand(data.size(0),-1,-1).contiguous()
# compute STN from pytorch function
grid = F.affine_grid(theta, data.size())
transformed_input_tensor1 = F.grid_sample(data, grid)
# compute STN from custom function
transformed_input_tensor2 = spatial_transformer_network(data, theta)
print('output diff: %f' % (transformed_input_tensor1-transformed_input_tensor2).sum().item())
# check the gradient
loss = transformed_input_tensor1.sum()
loss.backward(retain_graph=True)
grad1 = theta_base.grad.clone()
loss = transformed_input_tensor2.sum()
loss.backward(retain_graph=True)
grad2 = theta_base.grad
print('grad diff: %f' % (grad1-grad2).sum().item())
in_grid = convert_image_np(
torchvision.utils.make_grid(input_tensor))
out_grid1 = convert_image_np(
torchvision.utils.make_grid(transformed_input_tensor1.detach().cpu()))
out_grid2 = convert_image_np(
torchvision.utils.make_grid(transformed_input_tensor2).detach().cpu())
out_grid3 = convert_image_np(
torchvision.utils.make_grid(torch.abs(transformed_input_tensor2-transformed_input_tensor1).detach().cpu()))
# Plot the results side-by-side
f, axarr = plt.subplots(1, 4)
axarr[0].imshow(in_grid)
axarr[0].set_title('Dataset Images')
axarr[1].imshow(out_grid1)
axarr[1].set_title('Transformed (Pytorch Function)')
axarr[2].imshow(out_grid2)
axarr[2].set_title('Transformed (Custom)')
axarr[3].imshow(out_grid3)
axarr[3].set_title('Transformed Diff')
plt.ioff()
plt.show()
This is the custom spatial transformer functions.
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import copy
import math
from torch.nn import init
from torch.autograd import Variable
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
from IPython import embed
from IPython.terminal.embed import InteractiveShellEmbed
plt.ion() # interactive mode
def spatial_transformer_network(input_fmap, theta, out_dims=None):
"""
Spatial Transformer Network layer implementation as described in [1].
The layer is composed of 3 elements:
- localisation_net: takes the original image as input and outputs
the parameters of the affine transformation that should be applied
to the input image.
- affine_grid_generator: generates a grid of (x,y) coordinates that
correspond to a set of points where the input should be sampled
to produce the transformed output.
- bilinear_sampler: takes as input the original image and the grid
and produces the output image using bilinear interpolation.
Input
-----
- input_fmap: output of the previous layer. Can be input if spatial
transformer layer is at the beginning of architecture. Should be
a tensor of shape (B, H, W, C).
- theta: affine transform tensor of shape (B, 6). Permits cropping,
translation and isotropic scaling. Initialize to identity matrix.
It is the output of the localization network.
Returns
-------
- out_fmap: transformed input feature map. Tensor of size (B, H, W, C).
"""
# grab input dimensions
B, C, H, W = input_fmap.size()
# reshape theta to (B, 2, 3)
theta_b = theta.view(B, 2, 3)
# generate grids of same size or upsample/downsample if specified
if out_dims:
out_H = out_dims[0]
out_W = out_dims[1]
batch_grids = affine_grid_generator(out_H, out_W, theta_b)
else:
batch_grids = affine_grid_generator(H, W, theta_b)
x_s = batch_grids[:, 0, :, :]
y_s = batch_grids[:, 1, :, :]
out_fmap = bilinear_sampler(input_fmap, x_s, y_s)
return out_fmap
def affine_grid_generator(height, width, theta):
"""
This function returns a sampling grid, which when
used with the bilinear sampler on the input feature
map, will create an output feature map that is an
affine transformation [1] of the input feature map.
Input
-----
- height: desired height of grid/output. Used
to downsample or upsample.
- width: desired width of grid/output. Used
to downsample or upsample.
- theta: affine transform matrices of shape (num_batch, 2, 3).
For each image in the batch, we have 6 theta parameters of
the form (2x3) that define the affine transformation T.
Returns
-------
- normalized gird (-1, 1) of shape (num_batch, 2, H, W).
The 2nd dimension has 2 components: (x, y) which are the
sampling points of the original image for each point in the
target image.
"""
# grab batch size
num_batch = theta.size()[0]
# create normalized 2d grid
x = torch.linspace(-1.0, 1.0, steps=width).to(device)
y = torch.linspace(-1.0, 1.0, steps=height).to(device)
x_t_flat = x.repeat(height).view(-1)
y_t_flat = y.view(-1, 1).repeat(1, width).view(-1)
# reshape to [x_t, y_t, 1] - (homogeneou form)
ones = torch.ones_like(x_t_flat)
sampling_grid = torch.stack([x_t_flat, y_t_flat, ones])
# repeat grid num_batch times
sampling_grid = sampling_grid.unsqueeze(0).repeat(num_batch, 1, 1)
# cast to float
theta = theta.float()
sampling_grid = sampling_grid.float()
# print("sampling_grid 1:", sampling_grid.shape)
# transform the sampling grid - batch multiply
# batch grid has shape (num_batch, 2, H*W)
batch_grids = torch.matmul(theta, sampling_grid)
# print("batch_grids 2:", batch_grids.shape)
# reshape to (num_batch, 2, H, W,)
batch_grids = batch_grids.view(num_batch, 2, height, width)
# print("batch_grids 3:", batch_grids.shape)
return batch_grids
def get_pixel_value(img, x, y):
"""
Utility function to get pixel value for coordinate
vectors x and y from a 4D tensor image.
Input
-----
- img: tensor of shape (B, H, W, C)
- x: flattened tensor of shape (B*H*W, )
- y: flattened tensor of shape (B*H*W, )
Returns
-------
- output: tensor of shape (B, H, W, C)
"""
# prepare img params
B, C, H, W = img.size()
batch_idx = torch.arange(0, B)
batch_idx = batch_idx.view(-1, 1, 1)
b = batch_idx.repeat(1, H, W)
indices = torch.stack([b, y, x], dim=3)
return torch.gather(img, 0, Variable(indices.long()).cuda())
# return x.gather(indices)
def bilinear_sampler(img, x, y):
"""
Performs bilinear sampling of the input images according to the
normalized coordinates provided by the sampling grid. Note that
the sampling is done identically for each channel of the input.
To test if the function works properly, output image should be
identical to input image when theta is initialized to identity
transform.
Input
-----
- img: batch of images in (B, H, W, C) layout.
- grid: x, y which is the output of affine_grid_generator.
Returns
-------
- interpolated images according to grids. Same size as grid.
"""
# prepare img params
B, C, H, W = img.size()
zero = 0.
# cast indices as float32 (for rescaling)
x = x.float().to(device)
y = y.float().to(device)
# rescale x and y to [0, WorH]
x = 0.5 * ((x + 1.0) * float(W-1))
y = 0.5 * ((y + 1.0) * float(H-1))
x = x.contiguous().view(-1)
y = y.contiguous().view(-1)
# grab 4 nearest corner points for each (x_i, y_i)
# we need a rectangle around the point of interest
x0 = x.floor().int()
x1 = x0 + 1
y0 = y.floor().int()
y1 = y0 + 1
max_x = W - 1
max_y = H - 1
x0_clamp = torch.clamp(x0, 0, max_x)
x1_clamp = torch.clamp(x1, 0, max_x)
y0_clamp = torch.clamp(y0, 0, max_y)
y1_clamp = torch.clamp(y1, 0, max_y)
dim2 = W
dim1 = W * H
base = None
if img.is_cuda:
base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, H * W).contiguous().view(-1).cuda()
else:
base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, H * W).contiguous().view(-1)
base_y0 = base + y0_clamp * dim2
base_y1 = base + y1_clamp * dim2
idx_y0_x0 = base_y0 + x0_clamp
idx_y0_x1 = base_y0 + x1_clamp
idx_y1_x0 = base_y1 + x0_clamp
idx_y1_x1 = base_y1 + x1_clamp
# (B,C,H,W) -> (B,H,W,C)
im_flat = img.permute(0,2,3,1).contiguous().view(-1, C)
i_y0_x0 = torch.gather(im_flat, 0, idx_y0_x0.unsqueeze(1).expand(-1,C).long())
i_y0_x1 = torch.gather(im_flat, 0, idx_y0_x1.unsqueeze(1).expand(-1,C).long())
i_y1_x0 = torch.gather(im_flat, 0, idx_y1_x0.unsqueeze(1).expand(-1,C).long())
i_y1_x1 = torch.gather(im_flat, 0, idx_y1_x1.unsqueeze(1).expand(-1,C).long())
# Check the out-of-boundary case.
x0_valid = (x0 <= max_x) & (x0 >= 0)
x1_valid = (x1 <= max_x) & (x1 >= 0)
y0_valid = (y0 <= max_y) & (y0 >= 0)
y1_valid = (y1 <= max_y) & (y1 >= 0)
x0 = x0.float()
x1 = x1.float()
y0 = y0.float()
y1 = y1.float()
w_y0_x0 = ((x1 - x) * (y1 - y) * (x1_valid * y1_valid).float()).unsqueeze(1)
w_y0_x1 = ((x - x0) * (y1 - y) * (x0_valid * y1_valid).float()).unsqueeze(1)
w_y1_x0 = ((x1 - x) * (y - y0) * (x1_valid * y0_valid).float()).unsqueeze(1)
w_y1_x1 = ((x - x0) * (y - y0) * (x0_valid * y0_valid).float()).unsqueeze(1)
out = w_y0_x0*i_y0_x0+w_y0_x1*i_y0_x1+w_y1_x0*i_y1_x0+w_y1_x1*i_y1_x1
return out.view(B, H, W, C).permute(0,3,1,2).contiguous()