As mentioned in this tutorial, I currently have to provide a manual backward function in a C++ extension.
From my understanding, this backward would compute the same gradients that autograd would compute in Python (just implemented in C++), or am I mistaken here?
Is there any way to show the functions that are actually executed during autograd's backward pass (to be sure I am not missing any part of the gradient calculation)?
The usual way to check on gradients is torch.autograd.gradcheck.
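For example, a minimal sketch of how `gradcheck` could be used (`my_op` here is a hypothetical stand-in for your extension's forward; double-precision inputs matter, because `gradcheck` compares analytical gradients against finite differences):

```python
import torch

def my_op(x, w):
    # Hypothetical stand-in for the extension's forward pass.
    return torch.sigmoid(x @ w)

# gradcheck needs double precision and requires_grad=True on the inputs
# it should differentiate with respect to.
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(3, 2, dtype=torch.double, requires_grad=True)

# Returns True if analytical and numerical gradients match, raises otherwise.
print(torch.autograd.gradcheck(my_op, (x, w), eps=1e-6, atol=1e-4))
```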
For autograd-traced calculations, you can traverse the graph by following .grad_fn and their .next_functions.
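A small sketch of walking that graph (the exact `*Backward` class names can vary between PyTorch versions). Note that these nodes are what autograd will call during the backward pass, not the forward ops themselves:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

def node_names(fn):
    # Collect the class names of all backward nodes reachable from fn
    # by following .next_functions.
    names = []
    stack = [fn]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        names.append(type(node).__name__)
        stack.extend(next_fn for next_fn, _ in node.next_functions)
    return names

# Typically something like ['SumBackward0', 'MulBackward0', 'AccumulateGrad'],
# ending in AccumulateGrad nodes for the leaf tensors.
print(node_names(y.grad_fn))
```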
Are the functions in .grad_fn the ones that were used during the forward pass, or the ones that have to be called to calculate the gradients? I.e., are they part of the forward or the backward pass?
As far as I know (I think this came up a couple of weeks ago), you cannot call these nodes directly, and you cannot actually access their saved inputs (except for Python-defined torch.autograd.Function subclasses).
If you only use basic ops (as opposed to custom kernels), you could see whether you can use automatic differentiation (edit: I tried, and it works out of the box).
In that case, however, it might not be necessary to use a C++ extension at all unless you do really funny stuff; Python speed has been rather decent for me. (I tried that with torch.nn.functional.bilinear once.)
Now, all that is theorizing; I don't know what is best for your use case.
I just looked at the tutorial again. I think part, possibly a large part, of the speed-up may come from the forward pass not creating a graph, so a fairer C++ vs. Python comparison would be to wrap the Python forward in with torch.no_grad(), or to implement forward and backward as an autograd Function in Python.
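For that second option, here is a minimal sketch of a Python torch.autograd.Function with a hand-written backward (a toy elementwise multiply, not your actual op), which is the Python analogue of the C++ extension's manual backward:

```python
import torch

class MyMul(torch.autograd.Function):
    """Toy op with a hand-written backward, analogous to the C++ version."""

    @staticmethod
    def forward(ctx, x, w):
        # Stash the inputs needed by the backward pass.
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # d(x*w)/dx = w and d(x*w)/dw = x, scaled by the incoming gradient.
        return grad_out * w, grad_out * x

# Verify the hand-written backward against finite differences.
a = torch.randn(5, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(MyMul.apply, (a, b)))
```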
Thanks a lot. So to sum this up: in your opinion it should be fine to stay with Python as long as I don't want to create fancy custom kernels (which I actually don't), since the speedup from C++ would be minimal in a fair comparison?
Don't take my word for it, though; that might be part of it. So if a, say, 10% speedup is worth writing it in C++, do try. But I don't think you get the full 30% just from moving to C++; a larger part of that likely comes from the custom backward (which you could implement similarly in Python).