Truncating part of a computational graph

Hi,

My goal is to keep the depth of the computational graph constant over a recursive sequence.

Let’s assume I have training data D_1, D_2, … and a model M. The model is then applied as follows:

M…(M(M(D_1) || D_2) || D_3)…

At the step that trains on D_3, I would like the computational graph to be that of M(M(M_1 || D_2) || D_3), where M_1 is the output of M(D_1) with no graph attached to it. My first attempt was (simplified code):

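while True:
    inputs, target = batch
    # saved_history still carries the graph of every earlier iteration
    history, prediction = model(torch.cat((self.saved_history, inputs), dim=0))
    loss = F.mse_loss(prediction, target)

    # rebinding the attribute does not cut the graph that history refers to
    del self.saved_history
    self.saved_history = history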

but that keeps the full computational history. (I guess torch.cat has already captured a reference to it, so deleting saved_history here is really a no-op.) What I really need is a way to cut the computational graph of the history tensor at the point where it refers to the saved_history tensor.

Detaching saved_history, as in model(torch.cat((self.saved_history.detach(), inputs), dim=0)), does not achieve my goal either: the graph then only covers the very last iteration, not the last two iterations, which is my goal in this example.
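
The two behaviours described above can be reproduced with a minimal, self-contained sketch. ToyModel, count_model_calls and run are hypothetical names made up for this illustration, not the real model or any PyTorch API. The toy model applies exactly one tanh per call, so counting the tanh nodes reachable from loss.grad_fn gives the number of model calls still held in the graph. No backward() is called; only the graph is inspected.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    """Hypothetical stand-in for M: maps (history || inputs) to (history, prediction)."""

    def __init__(self, dim=4):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        hidden = torch.tanh(self.linear(x))         # exactly one tanh per model call
        history = hidden.mean(dim=0, keepdim=True)  # one summary row carried forward
        prediction = hidden[-1:]                    # prediction read from the last row
        return history, prediction


def count_model_calls(tensor):
    """Count tanh nodes reachable from tensor.grad_fn, i.e. model calls in its graph."""
    seen, stack, count = set(), [tensor.grad_fn], 0
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if "Tanh" in fn.name():
            count += 1
        stack.extend(parent for parent, _ in fn.next_functions)
    return count


def run(steps, detach):
    """Run the loop from the post without backward(); return the graph depth per step."""
    torch.manual_seed(0)
    model = ToyModel()
    saved_history = torch.zeros(1, 4)
    depths = []
    for _ in range(steps):
        inputs, target = torch.randn(2, 4), torch.randn(1, 4)
        previous = saved_history.detach() if detach else saved_history
        history, prediction = model(torch.cat((previous, inputs), dim=0))
        loss = F.mse_loss(prediction, target)
        depths.append(count_model_calls(loss))
        del saved_history        # drops only the Python name, not the graph
        saved_history = history
    return depths


full_depths = run(steps=4, detach=False)
detached_depths = run(steps=4, detach=True)
print(full_depths)      # [1, 2, 3, 4]
print(detached_depths)  # [1, 1, 1, 1]
```

Without detach the depth grows by one model call per step, and with detach it stays at one. The behaviour asked for here would correspond to a depth that stays at two from the second step onwards.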

Is there a way to express this in PyTorch?

Piotr