Extending torch.autograd


I have a really simple question, but I did not manage to find an answer on the web.
I would like to extend torch.autograd. Let's assume that the forward function takes as input a 2 * N vector of control points and generates an output (for instance, a deformed image).
What exactly should the backward function return?
Should it be a single variable containing all the color map derivatives, and if so, what is the correct format to use?

Thank you very much for your help.


If your transform function is transform(control, image) -> transformed_image, where control is 2 * N, then the backward should be:

transform_bwd(grad_transformed_image) -> grad_control, grad_image

If you don't care about optimizing the image, you may not need to compute grad_image; backward can return None in its place.

The grad_transformed_image will have the same shape as transformed_image, and the computed grad_control should be 2 * N (the same shape as control).
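To make the shapes concrete, here is a minimal sketch of a custom `torch.autograd.Function`. The transform is a hypothetical linear toy (`LinearDeform`, with a made-up constant `basis` matrix of shape M x 2N), not a real TPS warp, and the image is kept flat (M elements) for simplicity. `backward` returns one gradient per input of `forward`, each with that input's shape, and uses `ctx.needs_input_grad` to skip `grad_image` when the image does not require gradients.

```python
import torch


class LinearDeform(torch.autograd.Function):
    """Hypothetical toy transform: out = image + basis @ control.flatten().

    control: (2, N) control-point coordinates
    image:   (M,) flattened image
    basis:   (M, 2 * N) constant matrix (made up for illustration)
    """

    @staticmethod
    def forward(ctx, control, image, basis):
        ctx.save_for_backward(basis)
        ctx.control_shape = control.shape
        return image + basis @ control.reshape(-1)

    @staticmethod
    def backward(ctx, grad_transformed_image):
        # grad_transformed_image has the same shape as the forward output: (M,)
        (basis,) = ctx.saved_tensors
        grad_control = grad_image = None
        if ctx.needs_input_grad[0]:
            # basis is exactly the Jacobian d(out)/d(control) laid out as M x 2N,
            # so grad_control = basis^T @ grad, reshaped back to 2 x N.
            grad_control = (basis.t() @ grad_transformed_image).reshape(ctx.control_shape)
        if ctx.needs_input_grad[1]:
            # out depends on image through the identity map.
            grad_image = grad_transformed_image
        # One return value per forward input; the constant basis gets None.
        return grad_control, grad_image, None


N, M = 3, 10
control = torch.randn(2, N, dtype=torch.float64, requires_grad=True)
image = torch.randn(M, dtype=torch.float64)  # not optimised, so grad_image is skipped
basis = torch.randn(M, 2 * N, dtype=torch.float64)

out = LinearDeform.apply(control, image, basis)
out.sum().backward()
print(control.grad.shape)  # torch.Size([2, 3])

# gradcheck compares backward against finite differences (float64 inputs recommended).
image_req = image.clone().requires_grad_(True)
print(torch.autograd.gradcheck(LinearDeform.apply, (control, image_req, basis)))  # True
```

The toy is linear on purpose, so that `basis` itself is the Jacobian and the backward formula is easy to read.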

One way to compute grad_control would be to first compute the Jacobian of the transformation and then multiply it by the incoming grad_transformed_image. The Jacobian is a matrix of size (2*N) x M, where M is the number of elements in the transformed image. Each entry in the matrix is a partial derivative: how much one control-point coordinate affects one specific element of the transformed image. grad_control is then the matrix-vector product jacobian x grad_transformed_image, with grad_transformed_image flattened to length M and the resulting 2*N vector reshaped to the shape of control.
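A small sketch of the Jacobian route, assuming a hypothetical nonlinear `transform` (a tanh of a random linear map, not a real warp). `torch.autograd.functional.jacobian` builds the full Jacobian, but it lays it out as output dims followed by input dims, i.e. (M, 2, N), so it is reshaped and transposed into the (2*N) x M layout above before the matrix-vector product. The result is compared with plain `.backward()`, which computes the same vector-Jacobian product without materialising the Jacobian.

```python
import torch

torch.manual_seed(0)
N, M = 3, 8
weights = torch.randn(M, 2 * N, dtype=torch.float64)  # hypothetical, fixed


def transform(control):
    # Hypothetical nonlinear map from a (2, N) control array to M pixel values.
    return torch.tanh(weights @ control.reshape(-1))


control = torch.randn(2, N, dtype=torch.float64)
grad_transformed_image = torch.randn(M, dtype=torch.float64)

# Full Jacobian; torch orders it as output dims then input dims: (M, 2, N).
J = torch.autograd.functional.jacobian(transform, control)
jac = J.reshape(M, 2 * N).t()  # (2*N) x M, as in the text
grad_control_explicit = (jac @ grad_transformed_image).reshape(2, N)

# What backward computes: the same vector-Jacobian product, without building J.
c = control.clone().requires_grad_(True)
transform(c).backward(grad_transformed_image)
grad_control_vjp = c.grad

print(torch.allclose(grad_control_explicit, grad_control_vjp))  # True
```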

In practice, you probably don't want to compute the Jacobian explicitly, since many of its entries will probably be 0. You can likely simplify the computation if you know properties of your transformation, for example if it is local (each control point only affects nearby pixels).
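As an illustration of exploiting locality, here is a sketch with a deliberately simple, hypothetical local transform (`SegmentAffine`): the flat image of length N * S is cut into N segments, and control point k, i.e. `control[:, k] = (a_k, b_k)`, only scales and shifts segment k. The dense Jacobian would be (2*N) x (N*S) with almost all entries zero; `backward` instead sums over each segment only, so the cost is proportional to the number of pixels. `torch.autograd.gradcheck` is used to check the hand-written gradient numerically.

```python
import torch


class SegmentAffine(torch.autograd.Function):
    """Hypothetical local transform on a flat image of length N * S.

    Segment k (S pixels) is only affected by control[:, k] = (a_k, b_k):
        out_seg_k = image_seg_k * (1 + a_k) + b_k
    """

    @staticmethod
    def forward(ctx, control, image):
        n = control.shape[1]
        segs = image.reshape(n, -1)  # (N, S)
        ctx.save_for_backward(control, image)
        out = segs * (1 + control[0].unsqueeze(1)) + control[1].unsqueeze(1)
        return out.reshape(-1)

    @staticmethod
    def backward(ctx, grad_out):
        control, image = ctx.saved_tensors
        n = control.shape[1]
        g = grad_out.reshape(n, -1)  # (N, S): incoming gradient per segment
        segs = image.reshape(n, -1)
        # Each control coordinate only sees its own S pixels, so each entry of
        # grad_control is a sum over S values instead of a product over all M.
        grad_a = (g * segs).sum(dim=1)  # d out / d a_k = image_seg_k
        grad_b = g.sum(dim=1)  # d out / d b_k = 1
        grad_control = torch.stack([grad_a, grad_b])  # (2, N), same shape as control
        grad_image = (g * (1 + control[0].unsqueeze(1))).reshape(-1)
        return grad_control, grad_image


torch.manual_seed(0)
N, S = 4, 5
control = torch.randn(2, N, dtype=torch.float64, requires_grad=True)
image = torch.randn(N * S, dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(SegmentAffine.apply, (control, image)))  # True
```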

Thank you very much, this is a really nice and clear answer! Just to be sure I understand correctly: if we consider just one coordinate of a control point (for instance, of a thin-plate spline (TPS)), this point modifies the color map of the transformed image (i.e., it locally changes the appearance of the deformed image). The corresponding value in the 2*N grad_control matrix would then basically be (badly written as a finite difference):

sum( (color map with this coordinate moved by delta_step - current color map) / delta_step * grad_transformed_image ), where the sum runs over all pixels

(Should it be flattened instead of using the sum?)
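One way to check this formula, sketched with a hypothetical smooth `transform` standing in for the TPS warp (a sine of a random linear map producing an H x W color map): nudge each coordinate by `delta_step`, take the finite-difference color map, and reduce it against `grad_transformed_image` with a single sum over all pixels. Summing the elementwise product and taking the dot product of the flattened tensors give the same number, and the result should match what autograd's backward produces, up to finite-difference error.

```python
import torch

torch.manual_seed(0)
N, H, W = 3, 4, 5
weights = torch.randn(H * W, 2 * N, dtype=torch.float64)  # hypothetical, fixed


def transform(control):
    # Hypothetical smooth stand-in for a TPS warp: (2, N) -> (H, W) color map.
    return torch.sin(weights @ control.reshape(-1)).reshape(H, W)


control = torch.randn(2, N, dtype=torch.float64)
grad_transformed_image = torch.randn(H, W, dtype=torch.float64)
delta_step = 1e-6

current = transform(control)
grad_fd = torch.zeros_like(control)
for i in range(2):
    for k in range(N):
        shifted = control.clone()
        shifted[i, k] += delta_step  # nudge one coordinate of one control point
        diff = (transform(shifted) - current) / delta_step  # (H, W) finite difference
        # A single sum over every pixel of the elementwise product ...
        grad_fd[i, k] = (diff * grad_transformed_image).sum()
        # ... equals the dot product of the flattened tensors.
        assert torch.isclose(grad_fd[i, k], diff.flatten() @ grad_transformed_image.flatten())

# Compare with the vector-Jacobian product computed by autograd.
c = control.clone().requires_grad_(True)
transform(c).backward(grad_transformed_image)
print(torch.allclose(grad_fd, c.grad, atol=1e-4))  # True
```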

Thank you again