I want to compute the Laplacian of a batch of satellite images of a single location, treating each channel of each image as its own scalar field. The Laplacian is to be taken with respect to a single shared frame of reference. In other words, I may have a set of 10 images taken from 10 different angles, and we have developed a way to rectify them all into a single coordinate system.
The function I want to take the gradient of is this (pseudo-code):
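def get_pixel_values(local_coords, images, xforms):
    # shared local frame -> geographic coordinates
    geo_coords = convert_local_to_lat_lon(local_coords)
    # geographic coordinates -> per-image pixel locations
    pixel_coords = get_pixel_locations(geo_coords, xforms)
    # rescale pixel locations to the range grid_sample expects
    scaled_pixel_coords = scale(pixel_coords)
    pixel_values = grid_sample(images, grid=scaled_pixel_coords)
    return pixel_values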
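For context on the scale step: grid_sample expects grid coordinates normalized to [-1, 1], with x (width) first and y (height) second. Below is a minimal sketch of that normalization. It assumes align_corners=True and (x, y) pixel order, and the height and width arguments are added for illustration only; my real scale may differ in those details.

```python
import torch
import torch.nn.functional as F

def scale(pixel_coords, height, width):
    """Map (x, y) pixel coordinates to grid_sample's [-1, 1] range.

    Assumes align_corners=True, so -1 and +1 are the centers of the
    corner pixels. The height/width arguments are illustrative.
    """
    x = 2.0 * pixel_coords[..., 0] / (width - 1) - 1.0
    y = 2.0 * pixel_coords[..., 1] / (height - 1) - 1.0
    return torch.stack([x, y], dim=-1)

# One 3-channel 4x5 image whose values are easy to identify.
image = torch.arange(3 * 4 * 5, dtype=torch.float32).reshape(1, 3, 4, 5)

# Sample at pixel (x=2, y=1); the grid has shape (N, H_out, W_out, 2).
pixel_coords = torch.tensor([[[[2.0, 1.0]]]])
grid = scale(pixel_coords, height=4, width=5)
sampled = F.grid_sample(image, grid, mode="bilinear", align_corners=True)
# sampled has shape (1, 3, 1, 1) and should match image[0, :, 1, 2]
```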
My problem is that this function returns a tensor with more than one element, whereas the functorch.grad docs require the function being differentiated to return a single-element tensor.
I’ve created a version of the above function that takes only a single coordinate, a single image and a single transform, to meet the vmap requirements:
def img_grad(local_coords, images, xforms):
    def single_input(single_coord, single_image, single_xform):
        # re-add the batch dimension that vmap strips off
        return get_pixel_values(single_coord[None], single_image[None], single_xform[None])
but I’m stuck at meeting the grad requirement of a single-element tensor output…
Any help appreciated. I’m sorry I can’t provide more than pseudocode.
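To make the single-element requirement concrete, here is a minimal runnable sketch of the pattern I think I need: index out one channel so the function returns a 0-dim tensor, then take grad (or the trace of the Hessian for the Laplacian) and vmap over coordinates. It uses torch.func, which is where the functorch transforms live in PyTorch 2.x. toy_pixel_values is a hypothetical analytic stand-in for get_pixel_values, and channel_value and laplacian are names made up for this example. I have not confirmed that the same approach works through grid_sample in my real pipeline.

```python
import torch
from torch.func import grad, hessian, vmap

def toy_pixel_values(coord):
    """Hypothetical stand-in for get_pixel_values at one 2-D coordinate.

    Returns one value per channel, so the output has shape (2,).
    """
    x, y = coord[0], coord[1]
    return torch.stack([x**2 + y**2, torch.sin(x) * torch.cos(y)])

def channel_value(coord, channel):
    # Indexing one channel gives a 0-dim tensor, which grad accepts.
    return toy_pixel_values(coord)[channel]

def laplacian(coord, channel):
    # Laplacian = trace of the Hessian with respect to the coordinate.
    return hessian(channel_value)(coord, channel).diagonal().sum()

coords = torch.tensor([[0.5, -1.0], [2.0, 0.3]])

# Gradient of channel 0 at each coordinate: shape (num_coords, 2).
grads = vmap(grad(channel_value), in_dims=(0, None))(coords, 0)

# Laplacian of every channel at every coordinate: shape (num_coords, 2).
laps = torch.stack(
    [vmap(laplacian, in_dims=(0, None))(coords, c) for c in range(2)],
    dim=-1,
)
```

For this toy field the analytic answers are easy to check by hand: channel 0 has gradient (2x, 2y) and Laplacian 4, and channel 1 has Laplacian -2 sin(x) cos(y).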