Sampling images at 2D coordinates and computing the gradient with respect to those coordinates


In my code, I compute 2D coordinates from variables I want to optimize. I also have some images (distance maps, which are ground truth and therefore constant). I want to select pixels in those images by indexing them with my 2D coordinates, and then get the gradient with respect to my variables.
As I understand it, functions like index_select and gather only compute gradients with respect to the input, not with respect to the index.
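A minimal sketch of why plain indexing does not help here, assuming a toy 4x4 image and a single (x, y) coordinate (the variable names are hypothetical): to index with the coordinates they must first be cast to an integer dtype, and that cast cuts them out of the autograd graph. Integer tensors, such as the index argument of gather or index_select, cannot require gradients at all.

```python
import torch

# Hypothetical setup: a constant 4x4 "distance map" and one differentiable
# coordinate given as (x, y) in pixel units.
image = torch.arange(16, dtype=torch.float32).reshape(4, 4)
coords = torch.tensor([1.3, 2.7], requires_grad=True)

# Indexing needs integers; casting to long detaches the result from autograd.
idx = coords.round().long()
print(idx.requires_grad)  # False

value = image[idx[1], idx[0]]  # row = y, column = x
print(value.requires_grad)  # False: there is no path back to `coords`

# gather/index_select take integer index tensors, and integer tensors
# cannot be marked as requiring gradients.
try:
    torch.zeros(2, dtype=torch.long, requires_grad=True)
except RuntimeError as err:
    print("RuntimeError:", err)
```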

Is there a way to do this in PyTorch?
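One approach that may fit, sketched here under the assumption that bilinear interpolation between neighboring pixels is acceptable, is torch.nn.functional.grid_sample: it samples the image at continuous coordinates and backpropagates into the sampling grid as well as into the input. The helper sample_at and the toy image are hypothetical, and coordinates are assumed to be (x, y) in pixel units.

```python
import torch
import torch.nn.functional as F


def sample_at(image, xy):
    """Hypothetical helper: bilinearly sample `image` at pixel coordinates `xy`.

    image: (H, W) tensor, treated as constant.
    xy:    (N, 2) tensor of (x, y) pixel coordinates; may require grad.
    Returns an (N,) tensor of interpolated values.
    """
    H, W = image.shape
    # grid_sample expects coordinates normalized to [-1, 1]; with
    # align_corners=True, -1 and 1 map to the centers of the edge pixels.
    x_norm = 2.0 * xy[:, 0] / (W - 1) - 1.0
    y_norm = 2.0 * xy[:, 1] / (H - 1) - 1.0
    # The last grid dimension is ordered (x, y), i.e. (width, height).
    grid = torch.stack([x_norm, y_norm], dim=-1).reshape(1, 1, -1, 2)
    out = F.grid_sample(image.reshape(1, 1, H, W), grid,
                        mode="bilinear", align_corners=True)  # (1, 1, 1, N)
    return out.reshape(-1)


# Usage on a toy image where image[y, x] = 4 * y + x.
image = torch.arange(16, dtype=torch.float32).reshape(4, 4)
xy = torch.tensor([[1.5, 2.0]], requires_grad=True)
vals = sample_at(image, xy)
vals.sum().backward()
print(vals)     # about 9.5, halfway between image[2, 1] and image[2, 2]
print(xy.grad)  # about [[1., 4.]]: the slope along x and along y
```

The gradient with respect to a coordinate is the local slope of the bilinear interpolant, so it is piecewise constant within each pixel cell and not smooth across pixel boundaries. Coordinates outside the image read zeros under the default padding_mode="zeros"; padding_mode="border" may be a better fit for distance maps.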