I have an image and I want to take its gradient at each pixel. I can already get the gradient at every point of the image using a wrapper around `functional.grid_sample` and `torch.autograd.grad`, and now I want to compare that output to the output of `functorch.jacrev`.
My wrapper works fine (assuming you are okay with garbage at the border of each dimension), but using `jacrev` throws an error (full stack trace at the bottom). Any help sorting this out would be appreciated!
To reproduce:
I'm using torch 1.11.0+cu113 on Ubuntu.
Run the following code in a notebook:
import matplotlib.pyplot as plt
import numpy as np
import torch
import functorch as ft
from torch.nn import functional as F
def image_func_gen(image):
    """Wrap ``image`` as a differentiable function of pixel coordinates.

    Args:
        image: (I, C, H, W) tensor. H and W must be > 1, otherwise the
            coordinate normalization divides by zero.

    Returns:
        ``image_func(coords)``, which takes zero-based (P, 2 [row, col])
        coordinates and returns the values bilinearly sampled by
        ``F.grid_sample`` with shape (I, C, P, 1).
    """

    def _normalize_coordinates(row_col_coords):
        """Map zero-based (P, 2 [row, col]) pixel coords onto [-1, 1].

        With ``align_corners=True``, -1 and 1 correspond to the first and last
        pixel along each axis, so pixel index k maps to -1 + 2 * k / (size - 1).
        """
        # Build the (H, W) divisor on the coords' device. A CPU tensor with
        # dim > 0 cannot be mixed with CUDA coords.
        image_dims = torch.tensor(image.shape[-2:], device=row_col_coords.device)
        ratio_spanned = row_col_coords / (image_dims - 1)
        return ratio_spanned * 2 - 1

    def image_func(coords):
        """Sample ``image`` at (P, 2 [row, col]) coords; returns (I, C, P, 1)."""
        normalized_coords = _normalize_coordinates(coords)
        # grid_sample expects the last dimension ordered (x, y) == (col, row).
        normalized_coords = normalized_coords.flip(dims=(-1,))
        # (P, 2) -> (I, P, 1, 2): sample the P points as a P x 1 output image.
        # grid_sample requires the grid's batch size to match the input's.
        grid = normalized_coords[None, :, None, :].expand(image.shape[0], -1, -1, -1)
        return F.grid_sample(input=image, grid=grid, align_corners=True)

    return image_func
def image_gradient(coords, image):
    """Return the per-point gradient of the sampled image w.r.t. ``coords``.

    Args:
        coords: (P, 2 [row, col]) zero-based pixel coordinates. They must be
            part of the autograd graph (e.g. derived from tensors created with
            ``requires_grad=True``).
        image: (I, C, H, W) image tensor.

    Returns:
        Tensor with the same shape as ``coords`` holding d(value)/d(row, col)
        for each point. When several values are sampled per point (e.g.
        C > 1), their gradients are summed.
    """
    # Replicate-pad one extra row/column at the bottom/right; with this the
    # gradient at the edges (last row/col) is set to zero.
    padded_image = torch.nn.functional.pad(image, (0, 1, 0, 1), mode='replicate')
    image_func = image_func_gen(padded_image)
    pixel_values = image_func(coords)
    # ones_like matches the output's shape, dtype and device. The form
    # `torch.ones(len(values.squeeze()))` fails for a single point (0-d tensor
    # has no len), for C > 1 (shape mismatch), and on CUDA (ones on CPU).
    (gradients,) = torch.autograd.grad(pixel_values, coords, torch.ones_like(pixel_values))
    return gradients
# Build a small test image whose value equals its column index, so the true
# gradient (d/drow, d/dcol) is the same, (0, 1), at every point.
h, w = 5, 5
height, width = float(h), float(w)
image = torch.arange(width).expand(h, -1)

# Model the image (as a (1, 1, H, W) batch) as a function of (row, col) coords.
image_function = image_func_gen(image[None, None])

# Sample the function on the integer pixel grid; the samples should reproduce
# the original image.
y_range = torch.linspace(0.0, height - 1, h, requires_grad=True)
x_range = torch.linspace(0.0, width - 1, w, requires_grad=True)
coords = torch.cartesian_prod(y_range, x_range)  # (h * w, 2 [row, col])
output = image_function(coords)

# Per-point gradient through autograd.
gradient = image_gradient(coords, image[None, None])

# Compare against functorch's reverse-mode Jacobian.
# NOTE: this call was reported to raise an error with torch 1.11.0 + functorch.
ft.jacrev(image_function, argnums=0)(coords)
Results in: