Bilinear sampling doesn't propagate gradients correctly in PyTorch 0.4

After updating to PyTorch 0.4, the bilinear sampling in the following code no longer gives me correct gradients. The code takes exactly the same inputs as torch.nn.functional.grid_sample() and implements the same functionality; I marked where the discrepancy between PyTorch 0.4 and PyTorch 0.3 appears with a FIXME comment. I confirmed that the exact same code works in PyTorch 0.3.

Although PyTorch 0.4 provides torch.nn.functional.grid_sample and it gives me correct gradients, I want to get the custom bilinear sampling working, since I have a couple of other cases where I use a similar bilinear sampling scheme.
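
For reference, the quickest way to see whether any gradient reaches the sampling coordinates at all is something like the following (a minimal sketch with made-up shapes; SpatialTransformLayer is the class below):

layer = SpatialTransformLayer()
im = Variable(torch.rand(2, 3, 8, 8))              # (B, C, H, W)
xy = Variable(torch.rand(2, 2, 8, 8) * 7.0,        # pixel coordinates in [0, W-1]
              requires_grad=True)
out, mask = layer(im, xy)
out.sum().backward()
print(xy.grad)                                     # inspect what actually flows back into x, y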

I would really appreciate it if anyone could help.

class SpatialTransformLayer(nn.Module):
    def __init__(self):
        super(SpatialTransformLayer, self).__init__()

    def _interpolate(self, im, xy):
        '''
        args:
            im: input image (B, C, H, W)
            xy: sampling coordinates (B, 2(x,y), outH, outW)

        return:
            (B, C, outH, outW) tensor
        '''

        B, C, H, W = im.size()
        _, _, out_H, out_W = xy.size()

        x, y = torch.chunk(xy, 2, dim=1)
        x = x.contiguous().view(-1)
        y = y.contiguous().view(-1)
        # print(x.size(), y.size())

        x0 = torch.floor(x).int()
        x1 = x0 + 1
        y0 = torch.floor(y).int()
        y1 = y0 + 1

        max_x = W - 1
        max_y = H - 1

        x0_clamp = torch.clamp(x0, 0, max_x)
        x1_clamp = torch.clamp(x1, 0, max_x)
        y0_clamp = torch.clamp(y0, 0, max_y)
        y1_clamp = torch.clamp(y1, 0, max_y)

        dim2 = W
        dim1 = W * H

        base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, out_H * out_W).contiguous().view(-1)
        if im.is_cuda:
            base = base.cuda()
        base = Variable(base)

        base_y0 = base + y0_clamp * dim2
        base_y1 = base + y1_clamp * dim2

        idx_y0_x0 = base_y0 + x0_clamp
        idx_y0_x1 = base_y0 + x1_clamp
        idx_y1_x0 = base_y1 + x0_clamp
        idx_y1_x1 = base_y1 + x1_clamp

        # (B,C,H,W) -> (B,H,W,C)
        im_flat = im.permute(0,2,3,1).contiguous().view(-1, C)
        i_y0_x0 = torch.gather(im_flat, 0, idx_y0_x0.unsqueeze(1).expand(-1,C).long())
        i_y0_x1 = torch.gather(im_flat, 0, idx_y0_x1.unsqueeze(1).expand(-1,C).long())
        i_y1_x0 = torch.gather(im_flat, 0, idx_y1_x0.unsqueeze(1).expand(-1,C).long())
        i_y1_x1 = torch.gather(im_flat, 0, idx_y1_x1.unsqueeze(1).expand(-1,C).long())
        
        # Check the out-of-boundary case.
        x0_valid = (x0 <= max_x) & (x0 >= 0)
        x1_valid = (x1 <= max_x) & (x1 >= 0)
        y0_valid = (y0 <= max_y) & (y0 >= 0)
        y1_valid = (y1 <= max_y) & (y1 >= 0)

        x0 = x0.float()
        x1 = x1.float()
        y0 = y0.float()
        y1 = y1.float()

        # FIXME: gradients through x, y here get broken
        w_y0_x0 = ((x1 - x) * (y1 - y) * (x1_valid * y1_valid).float()).unsqueeze(1)
        w_y0_x1 = ((x - x0) * (y1 - y) * (x0_valid * y1_valid).float()).unsqueeze(1)
        w_y1_x0 = ((x1 - x) * (y - y0) * (x1_valid * y0_valid).float()).unsqueeze(1)
        w_y1_x1 = ((x - x0) * (y - y0) * (x0_valid * y0_valid).float()).unsqueeze(1)

        output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
        mask = (x0_valid * x1_valid * y0_valid * y1_valid).float()

        # (B,H,W,C) -> (B,C,H,W)
        return output.view(B, out_H, out_W, C).permute(0,3,1,2).contiguous(), mask.view(B, out_H, out_W, 1).permute(0,3,1,2).contiguous()

    def forward(self, im, xy):
        '''
        args:
            im: (B, C, H, W) tensor
            xy: (B, 2(xy), outH, outW) tensor
        return:
            (B, C, outH, outW) tensor
        '''
        B, C, H, W = im.size()

        im, mask = self._interpolate(im, xy)

        return im, mask
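
For context on the FIXME line: the four weights are the standard bilinear coefficients, so for a sample point (x, y) with integer corners (x0, y0) and (x1, y1) the output is w_y0_x0*I(y0,x0) + w_y0_x1*I(y0,x1) + w_y1_x0*I(y1,x0) + w_y1_x1*I(y1,x1), and it is exactly the (x1 - x), (x - x0), (y1 - y), (y - y0) factors that should carry the gradient back into x and y. A scalar sanity check with made-up numbers:

import math

x, y = 2.25, 3.75
x0, y0 = math.floor(x), math.floor(y)   # 2, 3
x1, y1 = x0 + 1, y0 + 1                 # 3, 4
w = [(x1 - x) * (y1 - y),               # w_y0_x0 = 0.75 * 0.25 = 0.1875
     (x - x0) * (y1 - y),               # w_y0_x1 = 0.25 * 0.25 = 0.0625
     (x1 - x) * (y - y0),               # w_y1_x0 = 0.75 * 0.75 = 0.5625
     (x - x0) * (y - y0)]               # w_y1_x1 = 0.25 * 0.75 = 0.1875
print(sum(w))                           # 1.0 for in-bounds points, so the output is a convex combination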

Could you provide a code snippet of how you’re computing the gradients?

With the following code, you should get the same output but different gradients from the PyTorch STN functions and my own implementation.

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False,
                   transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,),(0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)

# Get a batch of training data
data = next(iter(test_loader))[0].to(device)

input_tensor = data.cpu()

theta_base = Variable(torch.Tensor([0.8, 0, 0.0, 0.0, 0.8, 0]).view(1,2,3),requires_grad=True)
theta = theta_base.expand(data.size(0),-1,-1).contiguous()
# compute STN from pytorch function
grid = F.affine_grid(theta, data.size())
transformed_input_tensor1 = F.grid_sample(data, grid)

# compute STN from custom function
transformed_input_tensor2 = spatial_transformer_network(data, theta)

print('output diff: %f' % (transformed_input_tensor1-transformed_input_tensor2).sum().item())

# check the gradients
loss = transformed_input_tensor1.sum()
loss.backward(retain_graph=True)
grad1 = theta_base.grad.clone()
theta_base.grad.zero_()  # clear the accumulated gradient before the second backward pass
loss = transformed_input_tensor2.sum()
loss.backward(retain_graph=True)
grad2 = theta_base.grad
print('grad diff: %f' % (grad1-grad2).sum().item())

in_grid = convert_image_np(
    torchvision.utils.make_grid(input_tensor))

out_grid1 = convert_image_np(
    torchvision.utils.make_grid(transformed_input_tensor1.detach().cpu()))
out_grid2 = convert_image_np(
    torchvision.utils.make_grid(transformed_input_tensor2.detach().cpu()))
out_grid3 = convert_image_np(
    torchvision.utils.make_grid(torch.abs(transformed_input_tensor2-transformed_input_tensor1).detach().cpu()))

# Plot the results side-by-side
f, axarr = plt.subplots(1, 4)
axarr[0].imshow(in_grid)
axarr[0].set_title('Dataset Images')

axarr[1].imshow(out_grid1)
axarr[1].set_title('Transformed (Pytorch Function)')

axarr[2].imshow(out_grid2)
axarr[2].set_title('Transformed (Custom)')

axarr[3].imshow(out_grid3)
axarr[3].set_title('Transformed Diff')

plt.ioff()
plt.show()
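
(convert_image_np is not defined in the snippet above; I am assuming the usual helper from the PyTorch spatial transformer tutorial, roughly:)

def convert_image_np(inp):
    """Convert a normalized tensor into a displayable numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp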

These are the custom spatial transformer functions.

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import copy
import math
from torch.nn import init
from torch.autograd import Variable
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader

from IPython import embed
from IPython.terminal.embed import InteractiveShellEmbed

plt.ion() # interactive mode

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def spatial_transformer_network(input_fmap, theta, out_dims=None):
    """
        Spatial Transformer Network layer implementation as described in [1].
        The layer is composed of 3 elements:
        - localisation_net: takes the original image as input and outputs
          the parameters of the affine transformation that should be applied
          to the input image.
        - affine_grid_generator: generates a grid of (x,y) coordinates that
          correspond to a set of points where the input should be sampled
          to produce the transformed output.
        - bilinear_sampler: takes as input the original image and the grid
          and produces the output image using bilinear interpolation.
        Input
        -----
        - input_fmap: output of the previous layer. Can be the input image if the
          spatial transformer layer is at the beginning of the architecture. Should
          be a tensor of shape (B, C, H, W).
        - theta: affine transform tensor of shape (B, 2, 3) (or (B, 6); it is
          reshaped to (B, 2, 3) below). Permits cropping, translation and
          isotropic scaling. Initialize to the identity matrix. It is the output
          of the localization network.
        Returns
        -------
        - out_fmap: transformed input feature map. Tensor of shape (B, C, outH, outW).
    """
    # grab input dimensions
    B, C, H, W = input_fmap.size()

    # reshape theta to (B, 2, 3)
    theta_b = theta.view(B, 2, 3)
    
    # generate grids of same size or upsample/downsample if specified
    if out_dims:
        out_H = out_dims[0]
        out_W = out_dims[1]
        batch_grids = affine_grid_generator(out_H, out_W, theta_b)
    else:
        batch_grids = affine_grid_generator(H, W, theta_b)

    x_s = batch_grids[:, 0, :, :]
    y_s = batch_grids[:, 1, :, :]
    
    out_fmap = bilinear_sampler(input_fmap, x_s, y_s)

    return out_fmap
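
# Usage sketch (shapes matching the MNIST test script above; nothing beyond
# that script is assumed):
#   imgs:  (64, 1, 28, 28) input batch
#   theta: (64, 2, 3) affine parameters, or (64, 6), which gets reshaped to (B, 2, 3)
#   out = spatial_transformer_network(imgs, theta)   # -> (64, 1, 28, 28)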


def affine_grid_generator(height, width, theta):
    """
    This function returns a sampling grid, which when
    used with the bilinear sampler on the input feature
    map, will create an output feature map that is an
    affine transformation [1] of the input feature map.
    Input
    -----
    - height: desired height of grid/output. Used
      to downsample or upsample.
    - width: desired width of grid/output. Used
      to downsample or upsample.
    - theta: affine transform matrices of shape (num_batch, 2, 3).
      For each image in the batch, we have 6 theta parameters of
      the form (2x3) that define the affine transformation T.
    Returns
    -------
    - normalized grid (-1, 1) of shape (num_batch, 2, H, W).
      The 2nd dimension has 2 components: (x, y) which are the
      sampling points of the original image for each point in the
      target image.
    """

    # grab batch size
    num_batch = theta.size()[0]

    # create normalized 2d grid
    x = torch.linspace(-1.0, 1.0, steps=width).to(device)
    y = torch.linspace(-1.0, 1.0, steps=height).to(device)
    
    x_t_flat = x.repeat(height).view(-1)
    y_t_flat = y.view(-1, 1).repeat(1, width).view(-1)

    # reshape to [x_t, y_t, 1] (homogeneous form)
    ones = torch.ones_like(x_t_flat)
    
    sampling_grid = torch.stack([x_t_flat, y_t_flat, ones])
    
    # repeat grid num_batch times
    sampling_grid = sampling_grid.unsqueeze(0).repeat(num_batch, 1, 1)

    # cast to float
    theta = theta.float()
    sampling_grid = sampling_grid.float()
    # print("sampling_grid 1:", sampling_grid.shape)

    # transform the sampling grid - batch multiply
    # batch grid has shape (num_batch, 2, H*W)
    batch_grids = torch.matmul(theta, sampling_grid)

    # print("batch_grids 2:", batch_grids.shape)

    # reshape to (num_batch, 2, H, W,)
    batch_grids = batch_grids.view(num_batch, 2, height, width)
    # print("batch_grids 3:", batch_grids.shape)

    return batch_grids
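
# Quick check (a sketch): with an identity transform the returned grid should
# just be the normalized meshgrid itself.
#   theta_id = torch.Tensor([[1., 0., 0.], [0., 1., 0.]]).unsqueeze(0).to(device)  # (1, 2, 3)
#   grid = affine_grid_generator(4, 4, theta_id)                                   # (1, 2, 4, 4)
#   grid[:, 0] spans [-1, 1] along the width axis, grid[:, 1] along the height axis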
    
    
def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for coordinate
    vectors x and y from a  4D tensor image.
    Input
    -----
    - img: tensor of shape (B, C, H, W)
    - x: flattened tensor of shape (B*H*W, )
    - y: flattened tensor of shape (B*H*W, )
    Returns
    -------
    - output: tensor of shape (B, H, W, C)
    """
    
    # prepare img params
    B, C, H, W = img.size()
    
    batch_idx = torch.arange(0, B)
    batch_idx = batch_idx.view(-1, 1, 1)
    b = batch_idx.repeat(1, H, W)
    
    indices = torch.stack([b, y, x], dim=3)
    
    
    return torch.gather(img, 0, Variable(indices.long()).cuda())
    # return x.gather(indices)


def bilinear_sampler(img, x, y):
    """
    Performs bilinear sampling of the input images according to the
    normalized coordinates provided by the sampling grid. Note that
    the sampling is done identically for each channel of the input.
    To test if the function works properly, output image should be
    identical to input image when theta is initialized to identity
    transform.
    Input
    -----
    - img: batch of images in (B, C, H, W) layout.
    - x, y: sampling coordinates of shape (B, H, W); the two channels of the
      grid produced by affine_grid_generator.
    Returns
    -------
    - interpolated images according to grids. Same size as grid.
    """

    # prepare img params
    B, C, H, W = img.size()

    zero = 0.

    # cast indices as float32 (for rescaling)
    x = x.float().to(device)
    y = y.float().to(device)

    # rescale x and y from [-1, 1] to pixel coordinates in [0, W-1] and [0, H-1]
    x = 0.5 * ((x + 1.0) * float(W-1))
    y = 0.5 * ((y + 1.0) * float(H-1))
    x = x.contiguous().view(-1)
    y = y.contiguous().view(-1)

    # grab 4 nearest corner points for each (x_i, y_i)
    # we need a rectangle around the point of interest
    x0 = x.floor().int()
    x1 = x0 + 1
    y0 = y.floor().int()
    y1 = y0 + 1

    max_x = W - 1
    max_y = H - 1

    x0_clamp = torch.clamp(x0, 0, max_x)
    x1_clamp = torch.clamp(x1, 0, max_x)
    y0_clamp = torch.clamp(y0, 0, max_y)
    y1_clamp = torch.clamp(y1, 0, max_y)

    dim2 = W
    dim1 = W * H

    base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, H * W).contiguous().view(-1)
    if img.is_cuda:
        base = base.cuda()
    
    base_y0 = base + y0_clamp * dim2
    base_y1 = base + y1_clamp * dim2

    idx_y0_x0 = base_y0 + x0_clamp
    idx_y0_x1 = base_y0 + x1_clamp
    idx_y1_x0 = base_y1 + x0_clamp
    idx_y1_x1 = base_y1 + x1_clamp

    # (B,C,H,W) -> (B,H,W,C)
    im_flat = img.permute(0,2,3,1).contiguous().view(-1, C)
    i_y0_x0 = torch.gather(im_flat, 0, idx_y0_x0.unsqueeze(1).expand(-1,C).long())
    i_y0_x1 = torch.gather(im_flat, 0, idx_y0_x1.unsqueeze(1).expand(-1,C).long())
    i_y1_x0 = torch.gather(im_flat, 0, idx_y1_x0.unsqueeze(1).expand(-1,C).long())
    i_y1_x1 = torch.gather(im_flat, 0, idx_y1_x1.unsqueeze(1).expand(-1,C).long())
    
    # Check the out-of-boundary case.
    x0_valid = (x0 <= max_x) & (x0 >= 0)
    x1_valid = (x1 <= max_x) & (x1 >= 0)
    y0_valid = (y0 <= max_y) & (y0 >= 0)
    y1_valid = (y1 <= max_y) & (y1 >= 0)

    x0 = x0.float()
    x1 = x1.float()
    y0 = y0.float()
    y1 = y1.float()

    w_y0_x0 = ((x1 - x) * (y1 - y) * (x1_valid * y1_valid).float()).unsqueeze(1)
    w_y0_x1 = ((x - x0) * (y1 - y) * (x0_valid * y1_valid).float()).unsqueeze(1)
    w_y1_x0 = ((x1 - x) * (y - y0) * (x1_valid * y0_valid).float()).unsqueeze(1)
    w_y1_x1 = ((x - x0) * (y - y0) * (x0_valid * y0_valid).float()).unsqueeze(1)

    out = w_y0_x0*i_y0_x0+w_y0_x1*i_y0_x1+w_y1_x0*i_y1_x0+w_y1_x1*i_y1_x1
    
    return out.view(B, H, W, C).permute(0,3,1,2).contiguous()
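
For comparison, the equivalent path through the built-in ops is just the two calls already used in the test script above (a sketch; F.grid_sample expects the grid as (B, H, W, 2) with normalized (x, y) in the last dimension, whereas bilinear_sampler takes x_s, y_s of shape (B, H, W) and rescales them to pixel coordinates itself):

grid = F.affine_grid(theta, data.size())   # (B, H, W, 2), normalized (x, y)
out_builtin = F.grid_sample(data, grid)    # (B, C, H, W)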