Hi,
Is it possible to re-run just a part of a computational graph?
More details:
Suppose I have a network f with m inputs and n outputs, where each output does not necessarily depend on all inputs. The simplest example would be a linear map given by a triangular matrix (with a lower-triangular matrix, output j depends only on inputs 1 to j), though the networks I wish to consider are typically much more complex.
For an input x = (x_1, x_2, …, x_m), I can compute the output f(x) in the usual way with a forward pass. Now I would like to change just one of the inputs, x_i, and recompute only the new j-th output f(x)_j, performing only useful computations, that is, without recomputing the nodes whose inputs are unchanged. As I need to perform many such operations (changing x one coefficient at a time), the difference in cost between running a full forward pass each time and recomputing only the affected nodes would be tremendous: quadratic vs linear.
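To make the triangular-matrix case concrete, here is a minimal pure-Python sketch (no PyTorch; the helper names `lower_triangular`, `full_forward` and `update_input` are hypothetical). With y = L x and L lower-triangular, changing x[i] can only affect outputs j >= i, so one coefficient change costs O(n) instead of the O(n^2) of a full matrix-vector product. The update trick used here relies on linearity; for a nonlinear network one would instead re-evaluate each affected node from its cached inputs.

```python
import random


def lower_triangular(n, seed=0):
    """Random integer lower-triangular matrix: output j depends only on x[0..j]."""
    rng = random.Random(seed)
    return [[rng.randint(-3, 3) if k <= j else 0 for k in range(n)] for j in range(n)]


def full_forward(L, x):
    """Full forward pass y = L x: O(n^2) work regardless of what changed."""
    return [sum(L[j][k] * x[k] for k in range(len(x))) for j in range(len(L))]


def update_input(L, y, i, delta):
    """Incremental pass after x[i] += delta: only outputs j >= i can change.

    Because the map is linear, each affected output is corrected in O(1),
    so a single-coefficient change costs O(n) instead of O(n^2).
    """
    for j in range(i, len(L)):
        y[j] += L[j][i] * delta


n = 6
L = lower_triangular(n)
x = [1, 2, 3, 4, 5, 6]
y = full_forward(L, x)

# Change the inputs one coefficient at a time, updating y incrementally.
for i in range(n):
    delta = 10 * (i + 1)
    x[i] += delta
    update_input(L, y, i, delta)

print(y == full_forward(L, x))  # True
```

Integer entries keep the comparison exact; with floats, the incremental and full results could differ by rounding.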
Note that I do not care about backpropagation here; I only need forward passes.
So, questions:
- Is there a way to identify the subgraph (of the computational graph of f) that contains only the nodes depending on x_i and influencing f(x)_j?
With access to the computational graph, this should be doable in two passes: a forward pass that propagates, along the graph, the set of input indices each internal node depends on, and a backward pass that propagates the set of output indices each node has an influence on. In practice, I do not know how to do this with PyTorch.
- Is there a way to re-run only this subgraph?
(possibly extracting or cloning it first, if re-running it in place is not feasible)
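The two-pass idea and the partial re-run can be sketched on a hypothetical toy graph representation (a dict mapping each node to its function and parent names, in topological order). This is not PyTorch's API; it only illustrates the technique. It may be possible to apply the same passes to a graph traced with `torch.fx.symbolic_trace`, whose nodes list their inputs via `all_input_nodes` and their consumers via `users`, but that has not been checked here. Also note that in a traced PyTorch graph a node is usually a whole-tensor operation, so dependencies between individual coefficients would likely not be visible at that granularity.

```python
# Hypothetical toy graph (not PyTorch's API): node -> (function, parent names),
# listed in topological order; nodes whose function is None are inputs.
GRAPH = {
    "x0": (None, []),
    "x1": (None, []),
    "x2": (None, []),
    "a": (lambda u, v: u * v, ["x0", "x1"]),
    "b": (lambda u: u + 3, ["x1"]),
    "c": (lambda u, v: u * v, ["b", "x2"]),
    "y0": (lambda u: 2 * u, ["a"]),
    "y1": (lambda u, v: u - v, ["a", "c"]),
    "y2": (lambda u: u * u, ["c"]),
}
OUTPUTS = ["y0", "y1", "y2"]


def input_dependencies(graph):
    """Forward pass: the set of input nodes each node depends on."""
    deps = {}
    for name, (fn, parents) in graph.items():
        deps[name] = {name} if fn is None else set().union(*(deps[p] for p in parents))
    return deps


def output_influence(graph, outputs):
    """Backward pass: the set of output nodes each node has an influence on.

    Visiting nodes in reverse topological order guarantees that a node's set
    is complete before it is pushed to its parents.
    """
    infl = {name: ({name} if name in outputs else set()) for name in graph}
    for name in reversed(list(graph)):
        for p in graph[name][1]:
            infl[p] |= infl[name]
    return infl


def evaluate(graph, inputs):
    """Full forward pass; returns the value of every node (used as a cache)."""
    values = dict(inputs)
    for name, (fn, parents) in graph.items():
        if fn is not None:
            values[name] = fn(*(values[p] for p in parents))
    return values


def subgraph(graph, deps, infl, x_i, y_j):
    """Non-input nodes that depend on x_i and influence y_j, in topological order."""
    return [n for n, (fn, _) in graph.items()
            if fn is not None and x_i in deps[n] and y_j in infl[n]]


def rerun(graph, cache, nodes):
    """Re-evaluate only `nodes`; every other parent value is read from the cache.

    A parent outside the subgraph cannot depend on x_i (it influences y_j
    through its child), so its cached value is still valid. Nodes that depend
    on x_i but not on y_j are NOT updated, so other outputs may become stale.
    """
    for n in nodes:
        fn, parents = graph[n]
        cache[n] = fn(*(cache[p] for p in parents))
    return cache


deps = input_dependencies(GRAPH)
infl = output_influence(GRAPH, OUTPUTS)
cache = evaluate(GRAPH, {"x0": 1, "x1": 2, "x2": 3})

nodes = subgraph(GRAPH, deps, infl, "x2", "y2")
print(nodes)  # ['c', 'y2']
cache["x2"] = 5             # change one input...
rerun(GRAPH, cache, nodes)  # ...and recompute only what y2 needs
print(cache["y2"] == evaluate(GRAPH, {"x0": 1, "x1": 2, "x2": 5})["y2"])  # True
```

Here changing x2 and targeting y2 re-evaluates only two of the six computed nodes; a, b, y0 and y1 are never touched.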
Thanks!