Hi,
My goal is to keep the depth of the computational graph constant over a recursive sequence.
Let’s assume I have training data D_1, D_2, … and a model M. My model then behaves as follows (|| is concatenation):
M…(M(M(D_1) || D_2) || D_3)…
On the step that trains on D_3, I would like to have the computational graph of M(M(M_1 || D_2) || D_3), where M_1 is the output of M(D_1) but with no graph attached to it. My first attempt was (simplified code):
while True:
    inputs, target = batch
    history, prediction = model(torch.cat((self.saved_history, inputs), dim=0))
    loss = F.mse_loss(prediction, target)
    del self.saved_history
    self.saved_history = history
but that keeps the full computational history. I guess torch.cat has already captured a reference to the graph, so deleting saved_history here is really a no-op. What I really need is a way to cut the computational graph of the history tensor at the point where it refers to the saved_history tensor.
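To illustrate why the del has no effect on the graph, here is a minimal standalone sketch. The tensors a, saved and out are hypothetical stand-ins, not part of my real code; saved plays the role of saved_history.

```python
import torch

# `a` stands in for anything upstream that requires gradients.
a = torch.ones(2, requires_grad=True)

# `saved` plays the role of saved_history: it carries a graph back to `a`.
saved = a * 2

# torch.cat records its inputs' graph nodes in out.grad_fn.
out = torch.cat((saved, torch.zeros(2)), dim=0)

# This only removes the Python name; the graph is still held by out.grad_fn.
del saved

# Backward still reaches `a` through the concatenated tensor.
out.sum().backward()
grad_after_del = a.grad.clone()
print(grad_after_del)
```

The gradient still arrives at a, so the graph behind the deleted name stays alive for as long as out (or anything derived from it) is alive.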
Detaching saved_history, as in model(torch.cat((self.saved_history.detach(), inputs), dim=0)), does not achieve my goal either, because the graph then only covers the very last iteration (not the last 2 iterations, which is my goal in this example).
Is there a way to express this in PyTorch?
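One possible workaround, sketched below under several assumptions, is to keep a graph-free history from two steps back together with the previous inputs, and to recompute the previous step on every iteration so that the graph always spans exactly two applications of the model. ToyModel, anchor, prev_inputs and the tensor shapes are all hypothetical and only there to make the sketch runnable; this is not a built-in PyTorch mechanism, and it costs a second forward pass per step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModel(nn.Module):
    """Hypothetical stand-in: maps (history || inputs) to (history, prediction)."""
    def __init__(self, features=4, hist_rows=2):
        super().__init__()
        self.hist_rows = hist_rows
        self.linear = nn.Linear(features, features)

    def forward(self, x):
        y = torch.tanh(self.linear(x))
        history = y[-self.hist_rows:]   # last rows become the next history
        prediction = y.mean(dim=0)
        return history, prediction

torch.manual_seed(0)
model = ToyModel()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(3, 4), torch.randn(4)) for _ in range(5)]

anchor = torch.zeros(2, 4)   # graph-free history from two steps back
prev_inputs = None           # inputs of the previous step

for inputs, target in batches:
    if prev_inputs is None:
        history_in = anchor
    else:
        # Recompute the previous step from the graph-free anchor, so the
        # graph contains this step and the previous one, and nothing older.
        history_in, _ = model(torch.cat((anchor, prev_inputs), dim=0))
    history_out, prediction = model(torch.cat((history_in, inputs), dim=0))
    loss = F.mse_loss(prediction, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Slide the window: the history fed into this step becomes the new
    # anchor, with its graph cut off.
    anchor = history_in.detach()
    prev_inputs = inputs
```

Because the previous step is recomputed with the already updated parameters, its output differs slightly from what that step originally produced. Whether that is acceptable depends on the model.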
Piotr