Lazy execution in PyTorch?

Hi everyone.

I'd like to implement a custom backend for PT that would support lazy network execution: as ops get added, they are not executed right away but instead accumulate into a graph topology, and only when some trigger fires — e.g. a target tensor gets read — is the previously accumulated topology optimized, compiled, and executed on a custom non-CPU device.
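To illustrate the behaviour I mean, here is a toy sketch in plain Python (all names — `LazyTensor`, `item()`, etc. — are invented for illustration; this is not a PyTorch API):

```python
# Toy illustration of lazy execution: ops only record themselves into
# a graph; reading a value is the trigger that walks and executes the
# accumulated topology. A real backend would optimize/compile first.
class LazyTensor:
    def __init__(self, op, inputs, value=None):
        self.op = op          # operation name, or "const" for leaves
        self.inputs = inputs  # upstream LazyTensor nodes
        self._value = value   # concrete value, filled in on demand

    @staticmethod
    def const(v):
        return LazyTensor("const", [], value=v)

    def __add__(self, other):
        return LazyTensor("add", [self, other])  # no work done yet

    def __mul__(self, other):
        return LazyTensor("mul", [self, other])  # no work done yet

    def item(self):
        # The trigger: evaluate the recorded topology on first read.
        if self._value is None:
            vals = [t.item() for t in self.inputs]
            self._value = {"add": sum,
                           "mul": lambda v: v[0] * v[1]}[self.op](vals)
        return self._value

a, b = LazyTensor.const(2), LazyTensor.const(3)
c = a * b + a    # topology accumulated, nothing executed yet
print(c.item())  # reading the result triggers execution → 8
```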

Is that at all possible in PT? I couldn't find much documentation on how to write PT backends of the kind I just described.

Okay, now that I know more about PT I can answer my own question.

The solution is to use TorchScript and work directly with the traced graph.
For that you'll need a torch::jit backend.
Here's a useful article on the topic. Bear in mind, however, that it's already quite outdated: e.g. CustomFuseGraph would not be of much help, so you'll need to create fusion groups (i.e. meta-primitives you'll be in charge of compiling and executing) by hand, substituting the existing prims via SubgraphUtils::createSingletonSubgraphAndUpdateAliasing() and SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing().
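For anyone who finds the C++ helpers hard to picture, here is a toy Python model of what those two calls do conceptually (the list-of-strings "graph" and both function names are invented for this sketch; the real APIs operate on torch::jit::Graph nodes and additionally keep aliasing information consistent):

```python
# Toy model of fusion grouping. create_singleton_subgraph mirrors the
# idea of createSingletonSubgraphAndUpdateAliasing (wrap one node in a
# new fusion group); merge_into_subgraph mirrors
# mergeNodeIntoSubgraphAndUpdateAliasing (fold a neighbouring node
# into an existing group).
def create_singleton_subgraph(graph, node):
    # Replace `node` in the main graph with a fusion group
    # whose subgraph contains just that node.
    group = {"kind": "FusionGroup", "subgraph": [node]}
    graph[graph.index(node)] = group
    return group

def merge_into_subgraph(graph, node, group):
    # Pull `node` out of the main graph and into the group's subgraph.
    group["subgraph"].append(node)
    graph.remove(node)

graph = ["conv", "relu", "mul"]
group = create_singleton_subgraph(graph, "conv")
merge_into_subgraph(graph, "relu", group)
# The main graph now holds one fused meta-primitive plus "mul";
# the backend is responsible for compiling/executing the group.
print(group["subgraph"])  # ['conv', 'relu']
```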