Tracing backward pass in v1.1.0

Hello All,

My team at Microsoft Research India has built an adaptive compiler framework for deep learning jobs and implemented it on top of PyTorch (ASPLOS '19 paper). We used the tracing API in v0.3 to capture the compute graph representation of the network and its gradient computation, whose execution was then optimized by our framework.

We would like to port the implementation to v1.1.0, which entails capturing the compute graph through the tracing API. Since the backward pass doesn’t return any outputs, it doesn’t seem possible to use v1.1.0’s tracing API to build the computational graph for the backward pass. I also tried a graph visualization package to build the computational graph for the backward pass, but the nodes only carried generic autograd names such as “SelectBackward” or “ViewBackward”.
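For concreteness, here is a minimal sketch of what I tried in 1.1.0, with a toy model standing in for our real network and assuming torchviz as the visualization package (names here are placeholders):

```python
import torch
import torch.nn as nn
from torchviz import make_dot  # assuming torchviz is the visualization package

# Toy stand-in for our real network.
model = nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)

# torch.jit.trace in 1.1.0 only records the forward pass.
traced = torch.jit.trace(model, x)
print(traced.graph)  # forward ops only, no backward nodes

# Visualizing the autograd graph of the loss only exposes generic autograd
# node names (the "SelectBackward" / "ViewBackward" nodes we see for our
# model), not a graph representation we can hand to our framework.
loss = model(x).sum()
make_dot(loss, params=dict(model.named_parameters())).render("bwd_graph")
```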

Could someone please help me understand the tracing API so that I can capture the graph for the backward pass as well? Or could you suggest a different methodology?

Thank You,
Sanjay


Hi @singam-sanjay, could you give an example of how it worked in 0.3 and what you need from 1.1.0?
“SelectBackward” and the like are autograd nodes, and I’m not sure whether they’re expected to show up in your backward pass (given that you capture the graph representation and do the optimization through your compiler framework).
Happy to help if you can give more context. Thanks!

Hi @ailzhang,

Thanks for responding!

We used the tracing API in 0.3 in the following way:

  1. Trace the forward pass:
     trace, fwd_outputs = torch.jit.trace(model, fwd_args, nderivs=1)
  2. Run the backward pass:
     torch.autograd.backward(fwd_outputs, bwd_args)

After step 2, trace.graph() is updated with the nodes from the backward pass.
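For reference, a rough end-to-end sketch of that 0.3 workflow, with a toy model and placeholder inputs/gradients standing in for our real network (this uses the old 0.3-era API and will not run on 1.1.0):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable  # 0.3-era API

# Toy stand-in for our real network and inputs.
model = nn.Linear(4, 2)
fwd_args = (Variable(torch.randn(3, 4), requires_grad=True),)

# Step 1: trace the forward pass, requesting one level of derivatives.
trace, fwd_outputs = torch.jit.trace(model, fwd_args, nderivs=1)

# Step 2: run the backward pass; bwd_args is a placeholder gradient
# w.r.t. the toy model's output.
bwd_args = Variable(torch.ones(3, 2))
torch.autograd.backward(fwd_outputs, bwd_args)

# trace.graph() now contains both the forward and the backward nodes.
print(trace.graph())
```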

Could you please suggest an approach to generate equivalent results with the v1.1.0 API?

Thanks,
Sanjay

This is very interesting!
I have a somewhat similar question along these lines. If anyone feels I should address this in a different issue, I can create a new one.
Is there a way to “intercept” the backward pass on the fly, as it runs in the normal fashion? I’m asking specifically about the pytorch/xla package, but I know ptxla doesn’t have this, so maybe it belongs in autograd? The reason is that I would very much like to be able to tell which forward op each backward op is associated with (i.e. “this subgraph is, for the most part, the backward pass of the dropout layer”), and thus be able to keep track of this in some sidebanded fashion to be processed during graph compilation.
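To make the question concrete, here is a minimal eager-mode sketch of the kind of sideband mapping I have in mind, using forward hooks to tag each module’s output grad_fn (all names here are made up, and I don’t know whether something like this is feasible under pytorch/xla’s graph capture):

```python
import torch
import torch.nn as nn

# Map autograd nodes (grad_fn) to the forward module that produced them.
grad_fn_to_module = {}

def make_hook(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and output.grad_fn is not None:
            grad_fn_to_module[output.grad_fn] = name
    return hook

# Toy model; the dropout layer is the example from the question above.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
for name, module in model.named_modules():
    if name:  # skip the top-level container itself
        module.register_forward_hook(make_hook(name))

x = torch.randn(2, 4, requires_grad=True)
loss = model(x).sum()

# Walk the backward graph from the loss and report which backward node we
# can attribute to which forward module via the recorded grad_fns.
def walk(fn, seen):
    if fn is None or fn in seen:
        return
    seen.add(fn)
    print(type(fn).__name__, "->", grad_fn_to_module.get(fn, "<unattributed>"))
    for next_fn, _ in fn.next_functions:
        walk(next_fn, seen)

walk(loss.grad_fn, set())
```

Only the node at each module boundary gets tagged this way; interior backward nodes stay unattributed, which is why I’m wondering whether autograd or ptxla could expose this association directly.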