C++ API extending autograd functions

Using the Python API, one can create custom functions by subclassing Function and implementing the forward and backward methods, as explained in https://pytorch.org/docs/stable/notes/extending.html

However, in my work I am using PyTorch through the C++ API, which has been working great so far. What is not clear is how to implement custom functions (i.e., extend torch.autograd) with the C++ API. Creating a C++ class that inherits from torch::autograd::Function and implements the forward and backward passes does not seem to be the solution, because the Function class in function.h has no forward and backward methods to begin with.

What would be the best way to write custom functions using the C++ API? I understand that this new frontend is at an early stage and that we can expect changes to it, but it would be great to have some way (even a more verbose one) to create custom functions for which we define our own computation and gradients.


Hi,

Functions in C++ are quite different.
A Function only has an apply method, which represents the forward pass and should return output Variables with the proper .grad_fn attached.

For example, the DelayedError Function in the PyTorch source is just a no-op that raises an error if its backward is called.
You can see that it does nothing during the forward pass and attaches the Error Function to its outputs using the handy wrap_outputs helper.
For a forward Function, you first instantiate it, and the forward pass happens when its .apply method is called (note that it can only be called once). For backward Functions, you create the instance and attach it to your outputs, and the autograd engine calls its apply once, when it is needed.
For example, you can see in the PyTorch source how Gather is used in pure C++.
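To make the pattern above concrete, here is a toy, self-contained C++ model of it. Var, Node, MulConstantBackward, mul_constant, run_backward and demo are hypothetical names for illustration only; this is not libtorch's API. mul_constant plays the role of a forward Function's apply: it computes the output, creates the backward node and attaches it as the output's grad_fn (the job wrap_outputs does for real Functions). run_backward stands in for the autograd engine calling each backward node's apply.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Toy stand-ins (hypothetical, NOT libtorch) for Variable and a backward Function.
struct Node;

struct Var {
  double value = 0.0;
  double grad = 0.0;              // accumulated gradient, meaningful for leaves
  std::shared_ptr<Node> grad_fn;  // backward node that produced this Var; null for leaves
};

// A backward node: the engine calls apply() with d(out)/d(this node's output),
// and the node propagates gradients to the inputs it recorded during forward.
struct Node {
  virtual ~Node() = default;
  virtual void apply(double grad_output) = 0;
};

// Send a gradient to an input: continue through its grad_fn, or accumulate at a leaf.
inline void propagate(const std::shared_ptr<Var>& input, double grad) {
  if (input->grad_fn) {
    input->grad_fn->apply(grad);
  } else {
    input->grad += grad;
  }
}

// Backward node for y = x * c: dy/dx = c, so grad_input = grad_output * c.
struct MulConstantBackward : Node {
  std::shared_ptr<Var> input;
  double constant;
  MulConstantBackward(std::shared_ptr<Var> in, double c)
      : input(std::move(in)), constant(c) {}
  void apply(double grad_output) override {
    propagate(input, grad_output * constant);
  }
};

// "Forward" step: compute the output, then attach the backward node as its grad_fn.
inline std::shared_ptr<Var> mul_constant(const std::shared_ptr<Var>& x, double c) {
  auto out = std::make_shared<Var>();
  out->value = x->value * c;
  out->grad_fn = std::make_shared<MulConstantBackward>(x, c);
  return out;
}

// Engine stand-in: seed d(out)/d(out) = 1 and walk back through grad_fn nodes.
inline void run_backward(const std::shared_ptr<Var>& out) {
  propagate(out, 1.0);
}

// Usage: y = (x * 3) * 4 with x = 2, so y = 24 and dy/dx = 12.
inline void demo() {
  auto x = std::make_shared<Var>();
  x->value = 2.0;
  auto y = mul_constant(mul_constant(x, 3.0), 4.0);
  assert(y->value == 24.0);
  run_backward(y);
  assert(x->grad == 12.0);
}
```

A real engine also orders nodes topologically so each backward node runs once per backward pass; this toy simply recurses along every path, which is enough to show how grad_fn links outputs back to the nodes that compute their gradients.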

If you need this Function to be available in Python, you can see how we do it in the PyTorch source.


@albanD

The autograd documentation says that the Variable class is deprecated. Should I still use it to implement functions in C++, as the examples you’ve provided do?

Variable has been deprecated in Python but still exists in C++.
You should use it the same way the examples do. Note that in C++ you are responsible for building the proper computational graph yourself!

@albanD

Thanks a lot! Your responses have been very helpful!

Unfortunately, I’m still rather unsure how the whole autograd infrastructure fits together in C++. I’ve spent the day trying to understand the DelayedError and Gather functions, but I seem to be missing the bigger picture.

Looking at the Python tutorial, one finds the MulConstant example (repeated below):

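```python
from torch.autograd import Function

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # Stash the non-tensor constant on ctx for use in backward
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; the constant gets None
        return grad_output * ctx.constant, None
```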

As an example, how would one implement this using the C++ frontend?

Thank you very much for your help.


Edit:

I compiled PyTorch from source and reverse-engineered the built-in operations. Take a look at this minimal working example of the pow() function.
