Is there any way to automatically extract a computation graph representing the forward pass of a series of PyTorch computations? We would like the nodes of the graph to hold references to the functions that make up the forward pass, so that we can call them individually. We would also like access to the values of the intermediate outputs of each function.
This would be something different from the computation graph constructed by autograd, which represents the backward pass: its nodes only hold a reference to a grad_fn, the function that computes the derivative of the corresponding forward operation.
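For example, with gradients enabled on the input, inspecting grad_fn only surfaces backward functions such as SumBackward0, never the forward callables themselves (the exact grad_fn class names may vary slightly across PyTorch versions):

```python
import torch

# a small forward pass with gradients enabled so grad_fn is populated
x = torch.randn(10, requires_grad=True)
M = torch.randn(10, 10)
y = torch.sum(torch.tanh(torch.matmul(M, x) + x))

# autograd's graph exposes backward functions, not the forward ops
print(type(y.grad_fn).__name__)                       # e.g. SumBackward0
print(type(y.grad_fn.next_functions[0][0]).__name__)  # e.g. TanhBackward0
```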
For instance, if we provide the following code:
```python
import torch

x = torch.randn(10)
M = torch.randn(10, 10)
h = torch.matmul(M, x)
h = torch.add(h, x)
h = torch.tanh(h)
y = torch.sum(h)
```
We would like to automatically extract a graph structure like
```
       y
       |
     [sum]
       |
     [tanh]
       |
     [add]
     /    \
    |   [matmul]
    |   /     \
    |  /       \
    x           M
```
Where each bracketed node could be an object (call it Node, for instance) that stores a reference to its corresponding function as an attribute, along with the output that function produced during the forward pass.
It would be great if anyone could point to any PyTorch APIs that enable this kind of functionality!
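For concreteness, here is a sketch of the kind of extraction we are after, written against torch.fx's symbolic_trace and Interpreter (assuming torch.fx is in fact a suitable API for this; the CapturingInterpreter subclass and its use of node.meta are our own illustrative choices, not an established pattern):

```python
import torch
import torch.fx

def forward(x, M):
    h = torch.matmul(M, x)
    h = torch.add(h, x)
    h = torch.tanh(h)
    return torch.sum(h)

# symbolic_trace records the forward pass as a Graph of Nodes; each
# call_function node holds the forward callable in node.target
gm = torch.fx.symbolic_trace(forward)
for node in gm.graph.nodes:
    print(node.op, node.name, node.target)

# An Interpreter subclass can stash each node's output as it executes
class CapturingInterpreter(torch.fx.Interpreter):
    def run_node(self, n):
        out = super().run_node(n)
        n.meta['value'] = out  # intermediate output, kept on the node
        return out

x, M = torch.randn(10), torch.randn(10, 10)
y = CapturingInterpreter(gm).run(x, M)
```

After the run, every node in gm.graph carries its intermediate output in node.meta['value'], and the forward callables remain accessible via node.target.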