How to create a new torch.function() method?

In fact, I am trying to understand how and where a function like torch.sum() is implemented. I guess the code is somewhere in the C/C++ sources; I was looking in ATen, but couldn’t find it.

For instance, I would like to add a simple thing: a reversed cumsum operation. It would greatly simplify computing returns in reinforcement learning, instead of:

returns = rewards.flip(dims).cumsum(dim).flip(dims)
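A minimal sketch of such a helper in plain Python, assuming PyTorch is installed. `reverse_cumsum` and `reverse_cumsum_no_flip` are hypothetical names, not existing PyTorch functions. The first wraps the flip/cumsum/flip idiom above; the second is an equivalent formulation that avoids flipping by subtracting the prefix sum from the total:

```python
import torch


def reverse_cumsum(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: out[i] = sum(x[i:]) along `dim`."""
    # Same idiom as in the question: flip -> cumsum -> flip back.
    return x.flip(dim).cumsum(dim).flip(dim)


def reverse_cumsum_no_flip(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical alternative without flips."""
    # sum(x[i:]) = total - sum(x[:i]) = total - cumsum(x)[i] + x[i]
    return x.sum(dim, keepdim=True) - x.cumsum(dim) + x


rewards = torch.tensor([[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]])
returns = reverse_cumsum(rewards, dim=1)  # undiscounted returns per row
```

The two versions agree up to floating-point rounding; the flip-based one is the more literal translation of the one-liner.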

This is only a simple example among hundreds.

Is there a generic way to contribute to these basic functions?

I think that depends on how and where you plan on implementing it.

  • If it’s in Python, just importing it in the should do the trick.
  • If it’s in C++ within PyTorch (not ATen), you need to implement it in the appropriate file in torch/csrc; it can then be added to the torch.xx Python namespace by adding it here. Note that the function added there will receive PyObject arguments. You can use them to unpack your data and then pass it to a C function elsewhere in torch/csrc.
  • If it’s in C++ inside ATen, I’m less sure. From what I remember, those functions are bound through _C._VariableFunctions. This is generated from a templated file here that contains some special functions. The template is filled in by functions in this file. The ATen functions are loaded from a declarations YAML file, read here. This file is in the ATen repo here. From there, the cname of each function should be implemented in the libraries for all backends; the ones that are not implemented won’t be available. For example, the isSetTo method is implemented in TH, THC and THD and referenced in the declarations file, as you can see in this search.
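To check from Python that a function like torch.sum is not written in Python, and that it is reachable through the _C._VariableFunctions binding mentioned above, one can inspect it. This is a sketch assuming a recent PyTorch; torch._C is a private module whose layout may change between versions:

```python
import inspect

import torch

# Functions bound from C++ appear as builtins, not Python functions.
print(type(torch.sum).__name__)      # builtin_function_or_method
print(inspect.isbuiltin(torch.sum))  # True

# They are also available on the private _C._VariableFunctions binding.
print(hasattr(torch._C._VariableFunctions, "sum"))
print(hasattr(torch._C._VariableFunctions, "cumsum"))

# There is no Python source for a builtin, so inspect raises TypeError.
try:
    inspect.getsource(torch.sum)
except TypeError as err:
    print("no Python source:", err)
```

This is why searching the Python tree for `def sum` finds nothing: the body lives in the C++ backends.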

Note that for ATen I traced the code as I was writing this, so there might be mistakes, but it should give you enough pointers to do what you want. Namely, adding your method to the Declarations.cwrap file should create the method in torch (if it actually returns something). Then you just need to add the code in the backend you are interested in.
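Before touching any backend, the pure-Python route from the first bullet can be prototyped locally by attaching a function to the `torch` namespace and to `torch.Tensor` at runtime. This is only a sketch under that assumption (monkey-patching in user code), not how functions are registered upstream, and `reverse_cumsum` is a hypothetical name:

```python
import torch


def reverse_cumsum(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical op: sum from each index to the end along `dim`.
    return x.flip(dim).cumsum(dim).flip(dim)


# Local prototype only: expose it as torch.reverse_cumsum(x, dim) ...
if not hasattr(torch, "reverse_cumsum"):
    torch.reverse_cumsum = reverse_cumsum

# ... and as a Tensor method, x.reverse_cumsum(dim).
if not hasattr(torch.Tensor, "reverse_cumsum"):
    torch.Tensor.reverse_cumsum = reverse_cumsum
```

Usage: `torch.tensor([1.0, 2.0, 3.0]).reverse_cumsum(0)` gives `tensor([6., 5., 3.])`. Because `torch.Tensor` is a Python class, the plain function becomes a bound method with the tensor passed as `x`.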

I’ll chime in with a bit of additional detail:
I think aten/src/ATen/native/native_functions.yaml is probably central, and cumsum in particular is in the same directory, in ReduceOps.cpp, but that file only contains a few wrappers for dealing with the accumulation type. The definition itself is, as @albanD said, in TH/THC. When you open a feature request, it might be worth discussing with the devs whether it could be moved to native in that context, too.
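The accumulation-type handling done by those wrappers is visible from Python through cumsum’s optional `dtype` argument. This sketch assumes a recent PyTorch release, in which integer inputs are accumulated as int64 by default; older releases may behave differently:

```python
import torch

x = torch.tensor([100, 100, 100], dtype=torch.int8)

# Default: integer inputs are accumulated and returned as int64,
# so the running sum (200, 300) does not overflow int8 (max 127).
y = x.cumsum(0)
print(y, y.dtype)

# An explicit dtype casts the input before accumulating.
z = torch.cumsum(x, 0, dtype=torch.float64)
print(z, z.dtype)
```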

I wrote up a short guide on how to go from a Python function to its C++ implementation; maybe it can be useful as a start. It focuses on native functions, but I think the derivative parts may be relevant to cumsum. I’ll try to flesh out some parts about functions implemented directly in CPU and CUDA code.
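The derivative is where cumsum and a reversed cumsum meet: if y[i] is the sum of x[j] for j ≥ i, then dL/dx[j] is the sum of the upstream gradients g[i] for i ≤ j, so the backward of a reversed cumsum is an ordinary cumsum (and vice versa). A minimal Python prototype of that pairing with a custom `torch.autograd.Function`, useful for checking the math before writing a C++ kernel and its derivative definition; `ReverseCumsum` is a hypothetical name:

```python
import torch


class ReverseCumsum(torch.autograd.Function):
    """Hypothetical op: y[i] = sum(x[i:]) along `dim`."""

    @staticmethod
    def forward(ctx, x, dim):
        ctx.dim = dim
        return x.flip(dim).cumsum(dim).flip(dim)

    @staticmethod
    def backward(ctx, grad_out):
        # dL/dx[j] = sum_{i <= j} grad_out[i], i.e. a plain cumsum.
        # `dim` is a non-tensor input, so its gradient is None.
        return grad_out.cumsum(ctx.dim), None


x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)

# Compare the hand-written backward against finite differences.
print(torch.autograd.gradcheck(lambda t: ReverseCumsum.apply(t, 1), (x,)))
```

`gradcheck` needs double precision to be reliable, which is why the input is created as `torch.double`.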

Best regards



OK, thank you @albanD and @tom. Now I better understand how it’s done.

And the short guide is exactly what I was looking for.