Who has authored torch/csrc/autograd/python_functions.cpp or could answer some questions about it?

I am extending PyTorch for complex numbers and I’m running into RuntimeErrors because I don’t understand how the autograd system interacts with Function classes. Who could answer a few questions about this?

Thanks.

Ask away; the dev team can answer. We are watching the forums.

Ok.

  1. When I write a class that inherits from Function, some wrapping happens in C++ code before the Python code inside forward() and backward() runs. What is the convention

a) for passing arguments to Functions? Always pass Variables?
b) inside forward() and backward(): are certain input and return types expected? For example, are Variables always unwrapped to tensors before forward()? Should forward() always return a Variable, or always a tensor?
c) for using @staticmethod on forward() and backward()?

  2. Is there any tensor-type-specific code inside autograd that could cause problems when introducing a new tensor type?
  3. What’s the convention for calling Functions: Function.apply(args), Function().forward(args), or something else? I’m getting different errors for each.

Cheers,
Kinda stuck right now.

Why does backward() work with Variables while forward() works with tensors? Why doesn’t backward() unwrap the Variables the way forward() does?

a) No. In the new-style function format (using static methods and ctx), you can pass arbitrary arguments to functions.
b) Variables are unwrapped (but only those passed directly; Variables inside lists won’t be unwrapped), and all other arguments are passed as-is. forward() works only on tensors and should return only tensors. backward() works only on Variables and should return only Variables.
c) I don’t think I understand the question.
d) backward() works with Variables so that it can itself be differentiated (for higher-order derivatives). You can always use the @once_differentiable decorator to have it unwrap its inputs to tensors, but it will then raise an error if you try to differentiate twice.
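A minimal sketch of points a), b) and d), written against the current torch.autograd.Function API. Note an assumption about versions: since PyTorch 0.4, Variable and Tensor are merged, so today forward() and backward() both receive plain tensors; the Variable/tensor split described above reflects the older releases this thread was written for. The class names ScaledSquare and Cube are hypothetical. ScaledSquare takes a non-tensor argument and has a backward() built from ordinary differentiable ops, so a second derivative works; Cube marks its backward() with @once_differentiable, so differentiating it a second time fails.

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


# Hypothetical example Functions for illustration; not part of PyTorch itself.
class ScaledSquare(Function):
    """Computes scale * x**2, where scale is a plain Python number."""

    @staticmethod
    def forward(ctx, x, scale):
        # Non-tensor arguments such as `scale` are passed through as-is.
        ctx.save_for_backward(x)
        ctx.scale = scale
        return scale * x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # One return value per forward() input; None for the non-tensor `scale`.
        # Uses ordinary differentiable ops, so it can be differentiated again.
        return grad_output * 2 * ctx.scale * x, None


class Cube(Function):
    """Computes x**3 with a backward that supports first-order gradients only."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # once_differentiable runs this body under no_grad, so the result is
        # not recorded for a second differentiation.
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2


x = torch.tensor([1.0, 2.0], requires_grad=True)

# First and second derivatives through ScaledSquare (scale = 3).
y = ScaledSquare.apply(x, 3.0)
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # 6 * x -> [6., 12.]
(h,) = torch.autograd.grad(g.sum(), x)                      # constant 6 -> [6., 6.]

# Cube: the first derivative works, the second raises a RuntimeError.
z = Cube.apply(x)
(gc,) = torch.autograd.grad(z.sum(), x, create_graph=True)  # 3 * x**2 -> [3., 12.]
try:
    torch.autograd.grad(gc.sum(), x)
    second_order_failed = False
except RuntimeError:
    second_order_failed = True
```

The trailing None in ScaledSquare.backward() is needed because backward() returns one value per forward() argument, including non-tensor ones.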

I don’t think we have any code specific to tensor types.
The convention is Function.apply(*args).
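A short sketch of the apply() convention, assuming the current API (static forward/backward taking ctx), using an elementwise exp Function similar to the Exp example in the torch.autograd.Function documentation. apply() creates the ctx object and records the graph node, so calling forward() yourself (without a ctx) is not the supported path. The alias `exp` is only a naming convenience.

```python
import torch
from torch.autograd import Function


class Exp(Function):
    """Elementwise exp, written as a new-style (static-method) Function."""

    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        # Save the output: d/dx exp(x) = exp(x).
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result


# Call through the class-level apply(); binding it to a short name is a common idiom.
exp = Exp.apply

x = torch.zeros(3, requires_grad=True)
y = exp(x)          # exp(0) = 1 for every element
y.sum().backward()
print(x.grad)       # tensor([1., 1., 1.])
```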

Hope this helps

Thanks, that helped.