[Newb] Is there a way to step into Variable._execution_engine.run_backward()

Hi

I am trying to understand/instrument autograd in PyTorch. I’ve put a pdb.set_trace() before backward and traced the code until Variable._execution_engine.run_backward() is called, which pdb cannot step into. I presume this is where the code calls into the C++ extension.
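
For reference, here is roughly the setup (the tensors are just placeholders):

import pdb
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

pdb.set_trace()    # stepping with `s` from here
y.backward()       # eventually reaches Variable._execution_engine.run_backward(), which pdb cannot enter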

If so, is there a way to continue debugging and step into this function?

Thanks.

Hi,

This function and the whole autograd engine are implemented in cpp.
You will need a cpp debugger if you really want to step into it. It might not be very helpful though, depending on what you’re trying to achieve.

You are right. Following your advice and others’ in this thread, I cloned PyTorch’s source, built it with debugging flags, and set a pdb.set_trace() before backward. However, I don’t actually know what happens after Variable._execution_engine.run_backward(), so I can’t set a gdb breakpoint on the C++ function it calls that would stop execution when it reaches the underlying code.
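
In case it is useful to others, this is the kind of workflow I am attempting (the breakpoint symbol is just my guess at the engine’s entry point; it may differ between versions):

import os
import torch

print("PID:", os.getpid())
# From another terminal: gdb -p <pid>, then e.g.
#   break torch::autograd::Engine::execute
# (assuming that is still the cpp entry point), then continue.
input("Attach gdb and set a breakpoint, then press Enter...")

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()
y.backward()    # execution should stop inside the cpp engine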

As for my purpose, while I know autograd’s high level functionality, I am trying to study/tweak/understand how it is implemented.

The cpp engine is based on Functions, which are similar to the Python ones. They are the elementary operations that are considered.
The forward pass attaches a grad_fn to the Tensors it produces. You then have a graph of Functions stored in cpp, accessible from the next_functions field.
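
For example (a small sketch with an arbitrary tensor), you can inspect the recorded graph from Python:

import torch

x = torch.ones(2, requires_grad=True)
y = (x * 5).pow(2).sum()

# Each op of the forward recorded its backward Function as grad_fn
print(y.grad_fn)                                       # e.g. <SumBackward0 ...>
print(y.grad_fn.next_functions)                        # e.g. ((<PowBackward0 ...>, 0),)
print(y.grad_fn.next_functions[0][0].next_functions)   # e.g. ((<MulBackward0 ...>, 0),)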

When computing a backward, the engine traverses this graph. The entry point in cpp is here and you can look in the same file for all the code.

Where can I find the implementation for add-edge, sub-edge, multiple-edge and so on?

Hi,

Where did you see these functions?

I didn’t see these functions, but I know PyTorch can differentiate any operator, so there must be some files that implement these derivatives.

I’m not sure what you expect these functions to be or do.
There are no functions to manipulate the graph explicitly. It is built when performing operations but cannot be changed.


The file you pointed me to is actually what I expected to find, though I expected it to be in cpp files.

What I mean is something like:
y = x**4
and we can get the gradient of y w.r.t. x as 4x^3,
so the code must return something like n*x^(n-1).

import torch

x = torch.autograd.Variable(torch.Tensor([2]), requires_grad=True)
y = 5 * x**2
y.backward()   # populates x.grad with dy/dx evaluated at x = 2
x.grad         # -> 20

We get 20 because d(5x^2)/dx = 10x = 10 * 2 = 20.
And I would like to know the code inside backward(); there might be some function that returns 10 * input, right?

Hi,

This is actually never built explicitly.
What happens is that when the functions of the forward are seen, the corresponding backward ops are recorded.
And so if you say y = f(g(x)), where g is the square function and f is times 5, noting z = g(x), what the backward computes is dy/dx as dy/dz * dz/dx = df(z)/dz * dg(x)/dx = 5 * 2x. Note that the backward of f runs first (giving the 5), and then the backward of g multiplies that result by 2x.

Basically PyTorch is just doing backprop, one elementary function at a time. Nothing fancier.
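
Here is a small sketch of that, comparing autograd with the chain rule written by hand:

import torch

x = torch.tensor([2.0], requires_grad=True)
z = x ** 2          # g(x) = x^2
y = 5 * z           # f(z) = 5 * z

y.backward()

manual = 5 * (2 * x.detach())   # dy/dz * dz/dx = 5 * 2x
print(x.grad, manual)           # both give 20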


g(x) = x^2, but where does PyTorch compute dg(x)/dx as 2x, or the backward of any of those elementary functions?
This is what I want to find out.
Also, thank you for answering my question!!

You can find here the lines that define the backward of the pow functions as the pow_backward functions.
If you look for this function, you will find it here with a slightly more general formula for any power.
Also remember that these always compute the backward pass. So if we have o = f(i) and o is then used to compute some loss L, it computes dL/di = dL/do * do/di. dL/do is given in the grad argument of the function.
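
If it helps, here is a rough Python analogue of that formula (this is not the actual cpp code, just a sketch of the same math with a custom Function):

import torch

class MyPow(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, n):
        ctx.save_for_backward(x)
        ctx.n = n
        return x ** n

    @staticmethod
    def backward(ctx, grad_output):
        # Same idea as pow_backward: dL/dx = dL/do * n * x^(n-1)
        x, = ctx.saved_tensors
        return grad_output * ctx.n * x ** (ctx.n - 1), None

x = torch.tensor([2.0], requires_grad=True)
y = MyPow.apply(x, 4)
y.backward()
print(x.grad)    # 4 * 2^3 = 32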


This -> templates/Functions.cpp is what I want !!!
THX A LOT :grinning: