Tips on an unusual form of pipeline parallelism

Hi there,
I want to run a scenario that features an unusual sort of pipeline parallelism. In short, the idea is this:

There are N devices, d1, d2, …, dN, and each of them holds part of a neural network. This part can be a single layer or something more complex, but in any case it is expressed as a function f1, f2, …, fN for each device respectively. Each f_i receives as input the output of the previous function/device and has its own set of parameters, so the f_i are instances of subclasses of torch.nn.Module.
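To make the setup concrete, here is a minimal sketch of what I mean. The layer choices and shapes are placeholders, and I keep everything on CPU just so the snippet is self-contained; in the real setup each stage would sit on its own device d_i:

```python
import torch
import torch.nn as nn

# Illustration only: three stages f1, f2, f3, each of which would live on
# its own device in the real setup (kept on CPU here for simplicity).
f1 = nn.Linear(128, 256)
f2 = nn.Sequential(nn.Linear(256, 256), nn.ReLU())
f3 = nn.Linear(256, 10)
stages = [f1, f2, f3]

# Plain forward pass through the pipeline: each f_i consumes the output
# of the previous one.
x = torch.randn(32, 128)
for f in stages:
    x = f(x)
# x is now the output of f3, shape (32, 10); the loss would be computed here
```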

The catch is this: each device maintains a buffer that should only hold the last input it processed. When the final device/function computes an output, the loss is computed and the gradient with respect to the input of fN is sent to the previous device. Each device should then backpropagate the “pseudogradient” obtained by computing its “local” derivative at the input currently in its buffer and combining it with the backpropagated gradient it received from the next device.
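In rough, hypothetical code, the per-device behaviour I have in mind looks something like the wrapper below. The class and method names are just placeholders of mine, and device transfers are omitted for brevity:

```python
import torch
import torch.nn as nn

class Stage:
    """One pipeline stage f_i plus a one-element input buffer (hypothetical sketch)."""

    def __init__(self, module: nn.Module):
        self.module = module
        self.buffer_in = None   # last input this stage processed
        self.buffer_out = None  # output produced from that buffered input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # detach() so this stage's graph does not reach back into the
        # previous stage; overwriting the buffers should (I hope) drop the
        # graph of the previous run.
        self.buffer_in = x.detach().requires_grad_(True)
        self.buffer_out = self.module(self.buffer_in)
        return self.buffer_out

    def backward(self, grad_out: torch.Tensor) -> torch.Tensor:
        # Combine the gradient received from the next stage with the local
        # derivative at the buffered input (the "pseudogradient"). This also
        # accumulates gradients into the stage's parameters.
        self.buffer_out.backward(grad_out)
        return self.buffer_in.grad  # gradient to send to the previous stage
```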

This is a kind of “delayed” backpropagation that aims to avoid storing activations until a backward pass is computed.

It’s not entirely clear to me how something along these lines can be done in PyTorch. How can I make sure that each time an f_i runs on a new buffer input, it overwrites the computational graph of its previous run? I’m fairly sure I will need calls to detach() for this, to keep backward passes from propagating too “far back”. However, on their own they don’t seem to give me all the fine-grained control this needs.
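For reference, the kind of training step I’m trying to write (using the hypothetical Stage wrapper from the sketch above, with one optimizer per stage) looks roughly like this. The synchronous structure is a simplification, since in the real scenario the buffered input may be newer than the one that produced the incoming gradient, and I’m not sure the graph bookkeeping here is right:

```python
def train_step(stages, x, target, loss_fn, optimizers):
    # Forward: each stage overwrites its buffer with the new input.
    h = x
    for stage in stages:
        h = stage.forward(h)

    # Loss at the end of the pipeline; backward() fills in the gradient
    # w.r.t. fN's buffered input and fN's parameters.
    loss = loss_fn(h, target)
    loss.backward()
    grad = stages[-1].buffer_in.grad

    # "Delayed" backward: each earlier stage combines the received gradient
    # with its local derivative at whatever input currently sits in its buffer.
    for stage in reversed(stages[:-1]):
        grad = stage.backward(grad)

    for opt in optimizers:
        opt.step()
        opt.zero_grad()
    return loss.item()
```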

Any help is appreciated!