Functorch.grad and vmap

I want to compute the Laplacian of a batch of satellite images of a single location, treating each channel of each image as its own scalar field. The Laplacian is to be taken with respect to a single shared frame of reference. For example, I may have a set of 10 images taken from 10 different angles, so we have developed a way to rectify them all into a single coordinate system.
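Written out, and assuming the shared local frame is two-dimensional with coordinates $(x, y)$ (the post does not state its dimensionality), the target for image $n$ and channel $c$ is the Laplacian of the scalar field $I_{n,c}(x, y)$, which is the trace of its $2 \times 2$ Hessian $H_{n,c}$ with respect to the shared coordinates:

$$
\nabla^2 I_{n,c}(x, y) = \frac{\partial^2 I_{n,c}}{\partial x^2}(x, y) + \frac{\partial^2 I_{n,c}}{\partial y^2}(x, y) = \operatorname{tr} H_{n,c}(x, y)
$$

Since each image reaches the shared frame through its own transform, these derivatives pass (via the chain rule) through that image's coordinate conversion as well as through its pixel values.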

The function I want to take the gradient of is this (pseudocode):

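```python
import torch.nn.functional as F

def get_pixel_values(local_coords, images, xforms):
    geo_coords = convert_local_to_lat_lon(local_coords)      # shared local frame -> lat/lon
    pixel_coords = get_pixel_locations(geo_coords, xforms)   # lat/lon -> per-image pixel coords
    scaled_pixel_coords = scale(pixel_coords)                # presumably to grid_sample's [-1, 1] range
    pixel_values = F.grid_sample(images, grid=scaled_pixel_coords)
    return pixel_values
```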
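To make the output shape concrete, here is a minimal runnable stand-in for `get_pixel_values`. It is a sketch under a loud assumption: the lat/lon conversion, pixel lookup and scaling are collapsed into one hypothetical per-image 2x3 affine matrix that maps a local point directly to `grid_sample`'s normalized [-1, 1] coordinates. Even with one point per image, `F.grid_sample` returns one value per image and channel, shape `(N, C, 1, 1)`, so the function cannot be handed to `grad` as-is.

```python
import torch
import torch.nn.functional as F


def get_pixel_values(local_coords, images, xforms):
    """Hypothetical stand-in for the pseudocode above.

    local_coords: (N, 2) points in the shared local frame
    images:       (N, C, H, W)
    xforms:       (N, 2, 3) affine maps, ASSUMED to send a local point straight to
                  grid_sample's normalized [-1, 1] (x, y) coordinates; they replace
                  the lat/lon conversion, pixel lookup and scaling steps.
    """
    ones = torch.ones_like(local_coords[:, :1])
    homog = torch.cat([local_coords, ones], dim=1)            # (N, 3) homogeneous coords
    grid_xy = torch.einsum("nij,nj->ni", xforms, homog)       # (N, 2) normalized coords
    grid = grid_xy[:, None, None, :]                          # (N, 1, 1, 2) as grid_sample expects
    # One bilinear sample per image and channel: (N, C, 1, 1), not a scalar.
    return F.grid_sample(images, grid, mode="bilinear", align_corners=True)


images = torch.rand(10, 3, 64, 64)
xforms = torch.eye(2, 3).expand(10, 2, 3)   # identity: local coords already normalized
coords = -torch.ones(10, 2)                 # top-left pixel centre of every image
values = get_pixel_values(coords, images, xforms)
print(values.shape)  # torch.Size([10, 3, 1, 1])
```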

My problem is that this function outputs more than the single-element tensor that functorch.grad requires, according to its docs.
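A minimal sketch of the distinction, using `torch.func` (the PyTorch 2.x home of functorch's transforms; `from functorch import grad, jacrev` is the older spelling). `grad` only accepts a function that returns a single-element tensor, so one option is to differentiate one channel at a time; `jacrev` has no such restriction and returns one gradient row per output element. `pixel_values_at` is a hypothetical smooth stand-in for sampling all channels of one image at one local point, not the post's pipeline.

```python
import torch
from torch.func import grad, jacrev


def pixel_values_at(coord):
    """Hypothetical stand-in: all C = 3 channel values of one image at one local point."""
    freqs = torch.tensor([1.0, 2.0, 3.0])                       # made-up per-channel parameters
    return torch.sin(freqs * coord[0]) * torch.cos(freqs * coord[1])   # shape (3,)


coord = torch.tensor([0.3, -0.7])

# grad needs a 0-dim output, so differentiate a single channel ...
d_channel0 = grad(lambda c: pixel_values_at(c)[0])(coord)       # shape (2,)

# ... or let jacrev return every channel's gradient at once, one row per channel.
jac = jacrev(pixel_values_at)(coord)                            # shape (3, 2)
print(torch.allclose(jac[0], d_channel0))  # True
```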

I’ve created a version of the above function that takes only a single coordinate, a single image and a single transform, to meet vmap's requirements:

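```python
from functorch import vmap

def img_grad(local_coords, images, xforms):
    def single_input(single_coord, single_image, single_xform):
        # Inside vmap the batch dim is stripped; [None] adds back a batch dim of 1
        return get_pixel_values(single_coord[None], single_image[None], single_xform[None])
    return vmap(single_input)(local_coords, images, xforms)
```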
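The `[None]` pattern can be checked in isolation. Inside `vmap` each argument arrives without its leading batch dimension, so `[None]` restores a batch of one for code written for batched inputs. The sketch below uses hypothetical `sample_batch`/`sample_one` helpers that sample already-normalized grid coordinates (skipping the post's transforms); it also drops the extra dimension with `[0]` and confirms that vmapping the per-sample function reproduces the batched call.

```python
import torch
import torch.nn.functional as F
from torch.func import vmap   # functorch's vmap lives here in PyTorch >= 2.0


def sample_batch(grid_xy, images):
    # Batched call: grid_xy (N, 2) in [-1, 1], images (N, C, H, W) -> (N, C)
    grid = grid_xy[:, None, None, :]
    return F.grid_sample(images, grid, align_corners=True)[:, :, 0, 0]


def sample_one(single_xy, single_image):
    # Under vmap: single_xy is (2,), single_image is (C, H, W).
    # [None] re-adds a batch dim of 1; [0] removes it so the result is (C,).
    return sample_batch(single_xy[None], single_image[None])[0]


images = torch.rand(10, 3, 64, 64)
grid_xy = torch.rand(10, 2) * 2 - 1
per_sample = vmap(sample_one)(grid_xy, images)   # (10, 3)
batched = sample_batch(grid_xy, images)          # (10, 3)
print(torch.allclose(per_sample, batched))  # True
```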

but I’m stuck on meeting grad's requirement that the output be a single-element tensor.
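One way the pieces could fit together, as a minimal sketch: write the function for one local point and one image (returning its C channel values), take `torch.func.hessian` with respect to the point, sum the diagonal of each channel's 2x2 Hessian, and let two nested `vmap`s handle the batch of points and the batch of images. `pixel_values_at` and its `freqs` argument are hypothetical smooth stand-ins for one image's sampler and its data/transform; in real use they would be replaced by a single-point, single-image version of `get_pixel_values`.

```python
import torch
from torch.func import hessian, vmap


def pixel_values_at(coord, freqs):
    """Hypothetical smooth stand-in for sampling ONE image at ONE point.

    coord: (2,) local (x, y); freqs: (C,) made-up per-channel parameters standing
    in for the image and its transform. Returns the C channel values, shape (C,).
    """
    return torch.sin(freqs * coord[0]) * torch.cos(freqs * coord[1])


def laplacian_at(coord, freqs):
    # hessian differentiates w.r.t. coord (argnums=0 by default): shape (C, 2, 2)
    h = hessian(pixel_values_at)(coord, freqs)
    # Laplacian of each channel = trace of its 2x2 Hessian -> shape (C,)
    return h.diagonal(dim1=-2, dim2=-1).sum(-1)


# Inner vmap: many points for one image. Outer vmap: many images.
batched_laplacian = vmap(vmap(laplacian_at, in_dims=(0, None)), in_dims=(None, 0))

coords = torch.rand(100, 2)    # 100 points in the shared local frame
freqs = torch.rand(10, 3)      # 10 images x 3 channels
lap = batched_laplacian(coords, freqs)
print(lap.shape)  # torch.Size([10, 100, 3])
```

The stand-in is smooth on purpose. With `mode="bilinear"` (the `grid_sample` default), an interpolated image is linear along each pixel axis inside every pixel cell, so its pure second derivatives with respect to pixel coordinates are zero almost everywhere; a Laplacian taken straight through bilinear sampling may therefore reflect the interpolation and the coordinate transforms more than the image content.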

Any help appreciated. I’m sorry I can’t provide more than pseudocode.