It could be a useful feature to allow deleting variables from the computational graph once back-propagating before that point is no longer necessary, both to save memory and to allow for ‘persistent’ computational structures. Deletion would mean releasing the memory held by earlier variables and setting new leaf variables. An example use case is running an RNN for an unbounded number of steps while only needing, at every single step, the gradient computed K steps back. It feels like there should be a way to do this without constructing a new graph every time a gradient is needed.
What’s the problem with constructing a new graph every time? That’s the whole framework philosophy, so that it’s not restricted to static graphs. To cut the history, you can repackage the Variable in a new one: Variable(var.data) (this is acceptable as long as you discard var immediately; we don’t support having many Variables with the same data). Also, you can use .detach() to obtain a new Variable that doesn’t require the gradient.
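A minimal sketch of that idiom in a training loop, assuming a recent PyTorch where Variable has been merged into Tensor (since 0.4), so h.detach() plays the role of Variable(h.data). The toy model (an nn.RNNCell with a linear readout) and the helper name train_chunked are hypothetical. Here the hidden state is detached every k steps, i.e. plain chunked truncated BPTT rather than a sliding window.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: an RNN cell plus a linear readout.
torch.manual_seed(0)
cell = nn.RNNCell(input_size=3, hidden_size=4)
readout = nn.Linear(4, 1)
opt = torch.optim.SGD(list(cell.parameters()) + list(readout.parameters()), lr=0.01)

def train_chunked(xs, ys, k):
    """xs: (seq_len, batch, 3), ys: (seq_len, batch, 1). Returns per-chunk losses."""
    h = torch.zeros(xs.size(1), 4)
    losses = []
    for start in range(0, xs.size(0), k):
        # Cut the history: the new h is a leaf with no graph behind it
        # (modern equivalent of Variable(h.data)).
        h = h.detach()
        loss = 0.0
        for t in range(start, min(start + k, xs.size(0))):
            h = cell(xs[t], h)
            loss = loss + ((readout(h) - ys[t]) ** 2).mean()
        opt.zero_grad()
        loss.backward()  # gradients flow back at most k steps
        opt.step()
        losses.append(loss.item())
    return losses, h

xs, ys = torch.randn(10, 2, 3), torch.randn(10, 2, 1)
losses, h_last = train_chunked(xs, ys, k=4)
```

With 10 steps and k=4 this runs three chunks (4, 4 and 2 steps), and each backward pass stops at the detached state at the start of its chunk.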
Because it may avoid some redundant computation in some cases? Say I want to compute the gradient at every step of an RNN while looking K steps back; I may not want to construct a new graph from K steps back at every step, because most of the forward computation has already been done. Is this right?
Ok, so as far as I understand, you want to do k steps with your RNN, and then move that k-sized window one step forward, and compute backward for every move? So effectively (nearly) every path would get backproped through k times?
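A rough count of the "(nearly) every path" remark, assuming a sequence of length L with steps s = 0, …, L-1, a window of k that advances one step at a time, and one backward pass per step: the graph of step s is traversed once for every window t that contains it, whereas detaching every k steps traverses each step once.

```latex
N_{\text{sliding}}(s) = \bigl|\{\, t : s \le t \le \min(s + k - 1,\; L - 1) \,\}\bigr| = \min(k,\; L - s),
\qquad N_{\text{chunked}}(s) = 1
```

So the sliding window costs roughly kL single-step backward passes in total versus L for chunked truncation; the "nearly" covers the last k-1 steps, where L - s < k.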
Yes. I guess there are other use cases where subpaths will be backproped through multiple times, and it would be great if we could keep the graph instead of reconstructing the whole graph for every backward call. This deviates a little from my original question, but if we allow the graph to persist and grow, at some point forgetting some history (possibly by deleting variables and reassigning leaf variables) becomes necessary, hence the question I was asking.
No, we haven’t thought about this use case. Is it something you’re doing or just an example?
Should be straightforward to add though.
I don’t have a model that works like this yet, but I work with very long sequences that are not practical to backprop all the way through, and this is one approach I would try. I do think such flexibility would help enable new types of models, and I hope you agree. Thanks!
Yes it is. We’ve discussed it and will add it soon.
Great! You guys are awesome : )
Has this been implemented? I currently have an implementation in which it would be great to backpropagate every output of an LSTM K steps back, instead of detaching the hidden state every K steps. I’m not sure what the parameter updates would look like, but it would be interesting to explore, e.g. for char-rnn.
Hi. I have the same need. I want to do truncated backpropagation through time, and want to stop the gradient propagation past the hidden state from T time-steps ago. My previous approach was just to buffer the last T inputs and do a full computation of the forward pass from t-T to t, but that involves a lot of repeated computation!
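The buffer-and-recompute approach can be sketched as follows, assuming a toy nn.RNNCell with a linear readout and a squared-error loss on the last output of each window; sliding_window_recompute is a hypothetical helper name. Every step re-runs the whole buffered window, which is where the repeated computation comes from.

```python
import collections
import torch
import torch.nn as nn

# Hypothetical toy model: an RNN cell plus a linear readout.
torch.manual_seed(0)
cell = nn.RNNCell(input_size=3, hidden_size=4)
readout = nn.Linear(4, 1)

def sliding_window_recompute(xs, ys, T):
    """Yield (t, loss) after a backward pass through at most T steps."""
    h_start = torch.zeros(xs.size(1), 4)  # state entering the window, no history
    window = collections.deque()          # the last T inputs
    for t in range(xs.size(0)):
        window.append(xs[t])
        if len(window) > T:
            # Advance the window start by one step; no graph is needed for it.
            with torch.no_grad():
                h_start = cell(window.popleft(), h_start)
        # Full forward pass over the buffered inputs (up to T + 1 cell calls per step).
        h = h_start
        for x in window:
            h = cell(x, h)
        loss = ((readout(h) - ys[t]) ** 2).mean()
        cell.zero_grad()
        readout.zero_grad()
        loss.backward()  # gradient stops at h_start, T steps back
        yield t, loss.item()

xs, ys = torch.randn(6, 2, 3), torch.randn(6, 2, 1)
results = list(sliding_window_recompute(xs, ys, T=3))
```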
I’m looking for a way that I can just take the hidden state variable from T steps ago, and tell torch “don’t backpropagate past this variable!”. Is there a workaround that allows this?
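One possible workaround, sketched under the same toy-model assumptions (nn.RNNCell, linear readout, squared-error loss; sliding_window_chained is a hypothetical name) and not a built-in PyTorch feature: make each step's input state a fresh leaf with h.detach().requires_grad_(), so autograd stops at every step boundary, keep only the last T one-step graphs, and chain the backward calls by hand with tensor.backward(grad, retain_graph=True). The forward pass is never repeated; the price is retaining T step graphs and a Python loop of backward calls.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: an RNN cell plus a linear readout.
torch.manual_seed(0)
cell = nn.RNNCell(input_size=3, hidden_size=4)
readout = nn.Linear(4, 1)

def sliding_window_chained(xs, ys, T):
    """Yield (t, loss) after backpropagating through at most T steps."""
    h = torch.zeros(xs.size(1), 4)
    steps = []  # (leaf_state_in, state_out) pairs, newest last
    for t in range(xs.size(0)):
        h_in = h.detach().requires_grad_()  # this step's graph starts here
        h = cell(xs[t], h_in)
        steps.append((h_in, h))
        if len(steps) > T:
            steps.pop(0)  # dropping the oldest pair lets its graph be freed
        for h_leaf, _ in steps:
            h_leaf.grad = None  # leaf grads would otherwise accumulate across windows
        cell.zero_grad()
        readout.zero_grad()
        loss = ((readout(h) - ys[t]) ** 2).mean()
        # Step graphs are revisited by later windows, so they must be retained.
        loss.backward(retain_graph=True)
        # Walk backward through the window: the grad w.r.t. step i's input
        # state is the upstream gradient for step i-1's output state.
        for i in range(len(steps) - 1, 0, -1):
            steps[i - 1][1].backward(steps[i][0].grad, retain_graph=True)
        yield t, loss.item()

xs, ys = torch.randn(7, 2, 3), torch.randn(7, 2, 1)
results = list(sliding_window_chained(xs, ys, T=3))
```

Under these assumptions the parameter gradients at each step should match those obtained by rebuilding the last T steps from a detached start state; only the forward computation is shared between windows.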
Hi. I have the same problem. Has it been implemented yet, or has a workaround been found?
I think it would be very interesting to implement the feature (like a maximum graph depth or something like that).
For now I’ll try to implement something to crop the graph; I’ll probably start a new thread about it.
Check this solution:
@garibarba Thanks for that. I still find it quite slow, though (my crude benchmark shows it is about 2x slower than a native .backward()).