Backward pass using PyTorch/XLA

Hello, sorry if this is a repeat question, but how does the PyTorch autograd engine fit in with dispatching operations to the XLA backend? I understand that autograd functions save tensors from the forward pass for use in the backward pass, but I don’t think the XLA backend makes use of these functions or sets the grad_fn.


Autograd should work exactly the same as with regular CPU tensors.
Do you have a code sample where there is a discrepancy?
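A quick way to check this is to run the same small graph on an XLA device and compare it with the CPU. This is a minimal sketch, assuming `torch` is installed and, optionally, `torch_xla`. If `torch_xla` cannot be imported, the script falls back to the CPU. On either device the result should carry a `grad_fn` and produce the same gradient, because (as far as I understand the dispatcher) the autograd layer runs before the backend kernel is chosen.

```python
import torch

# Assumption: torch_xla may not be installed; fall back to the CPU so the
# script still runs and the autograd behaviour can be compared.
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
except ImportError:
    device = torch.device("cpu")

# Leaf tensor on the chosen device; autograd tracks it regardless of backend.
x = torch.randn(3, device=device, requires_grad=True)
y = (x * x).sum()

# The autograd graph is recorded on the device tensor as usual.
print(type(y.grad_fn).__name__)  # SumBackward0

y.backward()

# d/dx sum(x^2) = 2x; copying to the CPU forces any pending XLA computation.
print(torch.allclose(x.grad.cpu(), 2 * x.detach().cpu()))  # True
```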

No discrepancy; I was just trying to get a better understanding of how the lazy execution flow works with backprop.

My basic understanding is that lazy execution is completely hidden from PyTorch by the XLA binding, so PyTorch thinks every operation runs synchronously.
XLA only actually computes anything when we access the values on the CPU.

So lazy execution should have no impact on autograd :smiley:
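To make that picture concrete, here is a toy model in plain Python. It is only an illustration: the names `Lazy`, `leaf`, `mul`, `add`, `MulBackward`, `AddBackward` and `backward` are hypothetical, and this is not how PyTorch/XLA is implemented internally. Each operation immediately records a `grad_fn` node that saves its inputs, as eager autograd does, but the arithmetic itself is deferred until `.item()` is called, which plays the role of copying a value to the CPU.

```python
# Toy model (NOT PyTorch/XLA's implementation): autograd records grad_fn nodes
# eagerly, while the arithmetic itself is deferred until .item() is called.
EXECUTED = []  # log of ops that actually ran, to show when compute happens


class Lazy:
    def __init__(self, op, compute, grad_fn=None, requires_grad=False):
        self.op = op                # name of the op that produces this value
        self._compute = compute     # deferred computation (a thunk)
        self._value = None          # filled in on first materialization
        self.grad_fn = grad_fn      # backward node, recorded at build time
        self.requires_grad = requires_grad or grad_fn is not None
        self.grad = None

    def item(self):
        # Materialization point, analogous to copying an XLA tensor to the CPU.
        if self._value is None:
            self._value = self._compute()
            EXECUTED.append(self.op)
        return self._value


def leaf(v, requires_grad=False):
    return Lazy("leaf", lambda: v, requires_grad=requires_grad)


class MulBackward:
    def __init__(self, a, b):
        self.inputs = (a, b)  # saved for backward (still lazy, not values)

    def apply(self, grad_out):
        a, b = self.inputs
        # Gradients are themselves lazy expressions; nothing is computed here.
        return mul(grad_out, b, track=False), mul(grad_out, a, track=False)


class AddBackward:
    def __init__(self, a, b):
        self.inputs = (a, b)

    def apply(self, grad_out):
        return grad_out, grad_out


def _needs_grad(a, b, track):
    return track and (a.requires_grad or b.requires_grad)


def mul(a, b, track=True):
    node = MulBackward(a, b) if _needs_grad(a, b, track) else None
    return Lazy("mul", lambda: a.item() * b.item(), grad_fn=node)


def add(a, b, track=True):
    node = AddBackward(a, b) if _needs_grad(a, b, track) else None
    return Lazy("add", lambda: a.item() + b.item(), grad_fn=node)


def backward(t, grad):
    # Walk grad_fn nodes; contributions from every path are summed into leaves.
    if t.grad_fn is None:
        if t.requires_grad:
            t.grad = grad if t.grad is None else add(t.grad, grad, track=False)
        return
    for inp, g in zip(t.grad_fn.inputs, t.grad_fn.apply(grad)):
        backward(inp, g)


if __name__ == "__main__":
    x = leaf(3.0, requires_grad=True)
    y = add(mul(x, x), x)          # y = x^2 + x
    backward(y, leaf(1.0))         # builds the backward graph, still lazy
    print(EXECUTED)                # []
    print(x.grad.item())           # 7.0  (dy/dx = 2x + 1)
    print(len(EXECUTED) > 0)       # True
```

Because the backward nodes only hold references to lazy values, running `backward` just extends the pending graph instead of computing anything. As far as I understand, the real PyTorch/XLA runtime likewise traces the backward ops into the same pending graph and only compiles and executes it when a value is needed or at an explicit step boundary such as `xm.mark_step()`.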