Subclassing torch.autograd.Function to write a custom loss function and maintain gradients

I am trying to use gradient descent to learn the affine matrix values necessary to minimize a cost function.

In short, I have n pairs of images and a tensor of size [n, 3, 3], initialized to one identity matrix per image pair. For each pair of images, I use the corresponding matrix to apply an affine transformation to the images. I then calculate a fully normalized covariance (ncov) matrix of the two images. Importantly, I cannot use PyTorch for this step. The cost is the distance between the location of the largest value in the ncov matrix and the center of the ncov array, which is where the peak would sit for a perfect overlay. This difference (the loss) is returned as a single scalar value.
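For context, the part of this pipeline before the ncov step can stay inside PyTorch, so gradients reach the matrices up to that point. A minimal sketch, assuming single-channel images stored as an [n, 1, H, W] tensor and using F.affine_grid / F.grid_sample to apply each 3x3 matrix (that choice, the sizes, and the name `moving` are hypothetical, not from the question):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: n image pairs, single-channel H x W images.
n, H, W = 4, 32, 32
moving = torch.rand(n, 1, H, W)

# One 3x3 affine matrix per pair, initialised to the identity and
# marked as learnable so autograd tracks it.
theta = torch.eye(3).repeat(n, 1, 1).requires_grad_()

# affine_grid expects (n, 2, 3) matrices, so drop the constant last row.
grid = F.affine_grid(theta[:, :2, :], size=(n, 1, H, W), align_corners=False)
warped = F.grid_sample(moving, grid, mode="bilinear", align_corners=False)

print(warped.shape)  # torch.Size([4, 1, 32, 32])
```

With identity matrices the warp reproduces the input, and `warped` is still connected to `theta`. Converting `warped` to NumPy for the ncov is the point where that connection is lost.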

Because I can’t use PyTorch to calculate the ncov, the computational graph is broken and I cannot calculate the gradients. I gather that I can subclass torch.autograd.Function; however, it is very unclear to me how exactly to write the backward() method so that the new Function works properly with the autograd engine.

It seems like the values that should be returned from the backward() method depend heavily on which non-torch computations are used. Am I overthinking this? If not, is there a good reference simplifying how forward/backward should be structured for unconventional loss functions like this?
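On whether backward() depends on the non-torch computations: only through the local derivative, while the structure stays the same. A minimal sketch, with a hypothetical NumPy operation loss = sum(sin(x)) standing in for the real computation: forward leaves the graph via NumPy and converts back, and backward returns one gradient per forward input, shaped like that input, equal to grad_output times the hand-derived derivative.

```python
import numpy as np
import torch


class NumpySinSum(torch.autograd.Function):
    """Hypothetical example: loss = sum(sin(x)), computed in NumPy."""

    @staticmethod
    def forward(ctx, x):
        # Leave the autograd graph: detach and convert to NumPy.
        x_np = x.detach().cpu().numpy()
        loss = np.sin(x_np).sum()
        # Save whatever backward needs to evaluate the derivative.
        ctx.save_for_backward(x)
        return torch.as_tensor(loss, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(loss)/dx = cos(x), derived by hand.
        local_grad = np.cos(x.detach().cpu().numpy())
        # Chain rule: scale by the incoming gradient dL/d(loss) (a scalar here).
        grad_x = grad_output.item() * local_grad
        # One return value per forward input, with the same shape as that input.
        return torch.as_tensor(grad_x, dtype=x.dtype, device=x.device)


x = torch.randn(5, requires_grad=True)
loss = NumpySinSum.apply(x)  # call via .apply, not forward() directly
loss.backward()
```

After `backward()`, `x.grad` holds cos(x). Only the line computing `local_grad` would change for a different NumPy operation.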

There’s no easy way to do this, unfortunately. You’ll need to derive the derivatives of the non-torch operations you used by hand.
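One caveat before deriving anything: the location of the largest ncov value is piecewise constant in the affine parameters (a small change usually leaves the same pixel as the maximum), so its exact derivative is zero almost everywhere and a hand-derived backward would return zeros. A common workaround, swapped in here and not suggested in this thread, is a soft-argmax: a softmax-weighted average of pixel positions. Everything below is hypothetical (the class name, the temperature beta, the squared-distance loss, and taking the center as ((H-1)/2, (W-1)/2)), and it only covers the map-to-loss step; the ncov step would still need its own derivative chained in front of it.

```python
import numpy as np
import torch


class SoftArgmaxCenterLoss(torch.autograd.Function):
    """Hypothetical loss: squared distance between the soft-argmax of a 2-D
    score map and the map's centre. Forward and backward are plain NumPy."""

    @staticmethod
    def forward(ctx, score_map, beta):
        s = score_map.detach().cpu().numpy()
        H, W = s.shape
        rows, cols = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")

        # Softmax over all pixels; subtracting the max avoids overflow.
        z = beta * s
        p = np.exp(z - z.max())
        p /= p.sum()

        # Soft-argmax: probability-weighted mean row and column.
        r_hat = (p * rows).sum()
        c_hat = (p * cols).sum()
        r0, c0 = (H - 1) / 2, (W - 1) / 2
        loss = (r_hat - r0) ** 2 + (c_hat - c0) ** 2

        # Hand-derived local gradient, using d r_hat/d s_j = beta * p_j * (r_j - r_hat)
        # and the same for columns:
        # dloss/ds_j = beta * p_j * [2(r_hat - r0)(r_j - r_hat) + 2(c_hat - c0)(c_j - c_hat)]
        ctx.local_grad = beta * p * (
            2 * (r_hat - r0) * (rows - r_hat) + 2 * (c_hat - c0) * (cols - c_hat)
        )
        ctx.map_dtype, ctx.map_device = score_map.dtype, score_map.device
        return torch.as_tensor(loss, dtype=score_map.dtype, device=score_map.device)

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: incoming dL/d(loss) times the local gradient.
        grad_map = grad_output.item() * ctx.local_grad
        # One return value per forward input; beta is a plain float, so None.
        grad = torch.as_tensor(grad_map, dtype=ctx.map_dtype, device=ctx.map_device)
        return grad, None
```

Larger beta makes the soft-argmax closer to the hard argmax but also makes the gradients sharper; a moderate value is a reasonable starting point.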

You’re allowed to write backward() with non-torch operations as well, and you can use torch.autograd.gradcheck to check whether your gradients are correct.
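A sketch of how gradcheck catches a wrong backward. Both classes are hypothetical: they compute y = x**3 element-wise in NumPy, and the second one deliberately drops the factor of 3 from the derivative. gradcheck compares backward against finite differences and should be given float64 inputs, since float32 is usually too imprecise for its default tolerances.

```python
import numpy as np
import torch
from torch.autograd import gradcheck


class NumpyCube(torch.autograd.Function):
    """Hypothetical element-wise op y = x**3 evaluated in NumPy."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = x.detach().cpu().numpy() ** 3
        return torch.as_tensor(y, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # dy/dx = 3x^2; multiply element-wise by the incoming gradient.
        g = grad_output.detach().cpu().numpy() * 3 * x.detach().cpu().numpy() ** 2
        return torch.as_tensor(g, dtype=x.dtype, device=x.device)


class NumpyCubeWrong(torch.autograd.Function):
    """Same forward, but backward forgets the factor of 3."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = x.detach().cpu().numpy() ** 3
        return torch.as_tensor(y, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        g = grad_output.detach().cpu().numpy() * x.detach().cpu().numpy() ** 2
        return torch.as_tensor(g, dtype=x.dtype, device=x.device)


torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
print(gradcheck(NumpyCube.apply, (x,)))                             # True
print(gradcheck(NumpyCubeWrong.apply, (x,), raise_exception=False))  # False
```

With raise_exception=False, gradcheck returns False instead of raising, which makes it convenient in a quick test; the default raises an error describing the mismatch.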

OK, good to know, thank you! Are there any resources for deriving the derivatives of the non-torch operations I use by hand?

I don’t have any specific ones in mind, but googling “backpropagation derivation” turned up “A beginner’s guide to deriving and implementing backpropagation” by Pranav Budhwant (binaryandmore, on Medium). It would probably be good to look at a couple of tutorials and see which explanation works best for you.
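Whatever tutorial you use, the quantity backward() has to produce is the same chain-rule product. In the notation assumed here, y = f(x) is the non-torch operation, L is the final scalar loss, and grad_output is ∂L/∂y:

```latex
\frac{\partial L}{\partial x_i}
  = \sum_j \frac{\partial L}{\partial y_j}\,\frac{\partial y_j}{\partial x_i},
\qquad\text{i.e.}\qquad
\nabla_x L = J_f(x)^{\top}\,\nabla_y L .
```

For an element-wise f the Jacobian is diagonal, so this reduces to grad_output multiplied element-wise by f'(x); when f returns the scalar loss itself, it is grad_output times ∇ₓf(x).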

If you have any specific questions, happy to answer them here as well.


OK, I will see what I can do, thank you!