Rerun a part of the computational graph

Is it possible to re-run just a part of a computational graph?

More details:
Suppose I have a network f with m inputs and n outputs, where each output does not necessarily depend on all inputs. The simplest example would be multiplication by a triangular matrix, though the networks I have in mind are much more complex.
For an input x = (x_1, x_2, …, x_m), I can compute the output f(x) as usual with a forward pass. Now I would like to change just one of the inputs, x_i, and recompute only the new j-th output f(x)_j, performing only useful computations, that is, without recomputing the nodes whose inputs are the same as before. As I need to perform many such operations (changing x one coefficient at a time), the difference in complexity between running a full forward pass each time and running only the changed nodes will be tremendous: quadratic vs linear.
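To make the complexity claim concrete in the triangular-matrix case, here is a minimal plain-Python sketch (the helpers `forward_full` and `forward_one` are hypothetical names): a full pass with an n×n lower-triangular matrix costs O(n²) multiplications, while recomputing a single output y_j after changing one input touches at most n entries.

```python
def forward_full(L, x):
    """Full forward pass for a lower-triangular L: y_j = sum_{k<=j} L[j][k] * x[k].
    Costs O(n^2) multiplications."""
    return [sum(L[j][k] * x[k] for k in range(j + 1)) for j in range(len(x))]


def forward_one(L, x, j):
    """Recompute only output j: touches j + 1 entries, i.e. O(n) at most."""
    return sum(L[j][k] * x[k] for k in range(j + 1))


L = [[1, 0, 0],
     [2, 3, 0],
     [4, 5, 6]]
x = [1, 1, 1]
y = forward_full(L, x)      # full pass
x[1] = 2                    # change a single input coefficient
y[2] = forward_one(L, x, 2) # recompute only the requested output
# y[1] is now stale as well (it depends on x[1]), but it was not requested
```

Outputs with j < i never need recomputation here, since y_j only depends on x_1, …, x_j.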
Note that I do not care about backpropagation here, I need only forward passes.
So, questions:

  1. Is there a way to identify the subgraph (of the computational graph of f) which contains only the nodes that depend on x_i and have an effect on f(x)_j?

With access to the computational graph, this should be doable in two passes (one forward and one backward): by propagating sets of input (resp. output) variable indices along the graph, one learns, for every internal node, which input variables it depends on (resp. which output variables it influences). In practice, I do not know how to do this with PyTorch.
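A minimal sketch of those two passes in plain Python, assuming the graph is available as a hypothetical `parents` dictionary (node name → list of parent names) plus a list `order` of nodes in topological order; these structures are illustrative assumptions, not a PyTorch API. The toy graph mirrors a tiny triangular network: output 0 depends only on x0, output 1 on both inputs.

```python
def relevant_subgraph(order, parents, inputs, outputs, i, j):
    """Return the nodes (in topological order) that depend on inputs[i]
    and influence outputs[j]."""
    # Forward pass: set of input indices each node depends on
    depends = {}
    for node in order:
        s = {inputs.index(node)} if node in inputs else set()
        for p in parents.get(node, []):
            s |= depends[p]
        depends[node] = s
    # Invert the parent links to get the children of every node
    children = {node: [] for node in order}
    for node in order:
        for p in parents.get(node, []):
            children[p].append(node)
    # Backward pass: set of output indices each node influences
    influences = {}
    for node in reversed(order):
        s = {outputs.index(node)} if node in outputs else set()
        for c in children[node]:
            s |= influences[c]
        influences[node] = s
    return [n for n in order if i in depends[n] and j in influences[n]]


# Hypothetical toy graph: y0 = g(x0), y1 = h(x0, x1)
order = ["x0", "x1", "h0", "h1", "y0", "y1"]
parents = {"h0": ["x0"], "h1": ["x0", "x1"], "y0": ["h0"], "y1": ["h1"]}
inputs, outputs = ["x0", "x1"], ["y0", "y1"]
```

Since the sets are computed once for all nodes, any (i, j) query afterwards is a single filtering pass over the graph.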

  2. Is there a way to re-run only this subgraph?

(possibly extracting / cloning it first if re-running in place is not feasible…)
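Re-running only the subgraph is sound if the values of a previous full pass are cached: any parent of a subgraph node that depends on x_i also influences f(x)_j, so it is itself in the subgraph, and every other parent still holds its old value. A minimal plain-Python sketch under those assumptions (the graph representation, `ops`, and the helper names are hypothetical; `subgraph` is assumed to be the node list from question 1, in topological order):

```python
def forward(order, parents, ops, input_values):
    """Full forward pass over the DAG; returns (and caches) every node's value."""
    values = dict(input_values)
    for node in order:                    # topological order: parents first
        if node not in values:
            values[node] = ops[node](*(values[p] for p in parents[node]))
    return values


def rerun_subgraph(order, parents, ops, cache, subgraph, changed_input, new_value):
    """Re-evaluate only the nodes in `subgraph`; every other value is read
    from `cache`, the values stored by the previous full pass."""
    values = dict(cache)                  # copy so the cached pass stays intact
    values[changed_input] = new_value
    needed = set(subgraph) - {changed_input}
    for node in order:
        if node in needed:
            values[node] = ops[node](*(values[p] for p in parents[node]))
    return values


# Hypothetical toy graph: y0 = 2 * x0, y1 = 10 * (x0 + x1)
order = ["x0", "x1", "h0", "h1", "y0", "y1"]
parents = {"h0": ["x0"], "h1": ["x0", "x1"], "y0": ["h0"], "y1": ["h1"]}
ops = {"h0": lambda a: 2 * a, "h1": lambda a, b: a + b,
       "y0": lambda h: h, "y1": lambda h: 10 * h}

cache = forward(order, parents, ops, {"x0": 1, "x1": 2})
# Change x1 and recompute only the nodes between x1 and y1
new = rerun_subgraph(order, parents, ops, cache, ["x1", "h1", "y1"], "x1", 5)
```

Copying the cache rather than updating it in place is a design choice: it keeps the reference pass intact when many single-coefficient changes are tried from the same starting point.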


I thought the computational graph was more of a TensorFlow concept… one in which the graph is a fixed thing that your data “flows” through… whereas in PyTorch it’s “just Python”, so you can set breakpoints or use Python logic to start or stop the network(s)… or am I misunderstanding your question?

I suppose PyTorch does use a computational graph in autograd, but indeed the forward pass usually does not make use of it.
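As a small illustration of that autograd graph: every tensor computed from inputs with `requires_grad=True` carries a `grad_fn`, whose `next_functions` link back towards the inputs, with `AccumulateGrad` nodes wrapping the leaf tensors. The sketch below (the helper `leaves_reached` is hypothetical) walks those links to find which leaf inputs an output depends on. The nodes of this graph are backward functions, so it reveals the dependency structure but does not store the forward operations in a re-executable form.

```python
import torch


def leaves_reached(output: torch.Tensor) -> list:
    """Walk the autograd graph backwards from `output` and collect the leaf
    tensors (created with requires_grad=True) that it depends on."""
    leaves, seen, stack = [], set(), [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:      # None marks inputs without grad
            continue
        seen.add(fn)
        if hasattr(fn, "variable"):       # AccumulateGrad wraps a leaf tensor
            leaves.append(fn.variable)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return leaves


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
y0 = (a * 2).sum()   # depends on a only
y1 = (a * b).sum()   # depends on a and b
```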
I could write new .forward() functions by hand that compute only what is needed, if I know which nodes to compute and the network is simple; but for complex networks, I would like to do this automatically: take a network as input, identify the desired subgraph, and run it. For that, I need the computational graph of the network.
I guess I can have a look at the implementation of .backward() for inspiration, but that’s going to take a while…
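One possible route is `torch.fx` (not mentioned in the thread; this assumes a reasonably recent PyTorch in which `torch.fx.Interpreter.run` accepts an `initial_env` argument). `fx.symbolic_trace` records the forward pass of a module as a graph of nodes, on which the two passes from question 1 can be run, and an `fx.Interpreter` pre-filled with cached values skips every node outside the subgraph. This is a sketch under stated assumptions: the network `TwoHeadNet` and the helpers `dirty_nodes`, `Recorder`, and `rerun` are hypothetical; the granularity is whole tensors (one node per operation), not individual coefficients; and symbolic tracing does not support data-dependent Python control flow.

```python
import torch
import torch.fx as fx


class TwoHeadNet(torch.nn.Module):
    """Hypothetical toy network: inputs x0, x1; output 0 ignores x1."""

    def __init__(self):
        super().__init__()
        self.lin0 = torch.nn.Linear(4, 4)
        self.lin1 = torch.nn.Linear(4, 4)

    def forward(self, x0, x1):
        h0 = torch.relu(self.lin0(x0))
        h1 = self.lin1(x0 + x1)
        return h0, h1


def dirty_nodes(gm, i, j):
    """Nodes that depend on input i AND influence output j (two passes)."""
    nodes = list(gm.graph.nodes)                 # FX keeps nodes in topological order
    inputs = [n for n in nodes if n.op == "placeholder"]
    # Assumes forward returns a flat tuple of tensors
    outputs = next(n for n in nodes if n.op == "output").args[0]
    depends = {}                                 # forward pass: input indices
    for n in nodes:
        s = {inputs.index(n)} if n.op == "placeholder" else set()
        for p in n.all_input_nodes:
            s |= depends[p]
        depends[n] = s
    influences = {n: set() for n in nodes}       # backward pass: output indices
    for k, o in enumerate(outputs):
        influences[o].add(k)
    for n in reversed(nodes):
        for u in n.users:
            if u.op != "output":
                influences[n] |= influences[u]
    return {n for n in nodes if i in depends[n] and j in influences[n]}


class Recorder(fx.Interpreter):
    """Interpreter that keeps every intermediate value of a full forward pass."""

    def __init__(self, gm):
        super().__init__(gm)
        self.values = {}

    def run_node(self, n):
        self.values[n] = super().run_node(n)
        return self.values[n]


def rerun(gm, cached, dirty, *args):
    """Recompute only the dirty nodes: run() skips nodes already in initial_env."""
    env = {n: v for n, v in cached.items()
           if n not in dirty and n.op not in ("placeholder", "output")}
    return fx.Interpreter(gm).run(*args, initial_env=env)


torch.manual_seed(0)
net = TwoHeadNet()
gm = fx.symbolic_trace(net)
x0, x1 = torch.randn(4), torch.randn(4)
rec = Recorder(gm)
full = rec.run(x0, x1)                           # full pass, caching every node
dirty = dirty_nodes(gm, i=1, j=1)                # x1 -> add -> lin1
new_x1 = torch.randn(4)
out = rerun(gm, rec.values, dirty, x0, new_x1)   # lin0 and relu come from the cache
```

Placeholders are always re-run so that the positional arguments are consumed in order; only the intermediate nodes outside the dirty set are taken from the cache.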