Truncating backprop temporarily

Is it possible to accomplish the following (from Chainer) in PyTorch?

For reference, I’m trying to implement the “efficient” trust region optimisation from ACER: Trust region update. My code currently backprops through the entire computation graph of the policy gradient loss, which is very expensive. The efficient version cuts the graph just before the final layer and calculates gradients only up to that cut. However, the cut nodes then need to be rejoined to their parents, as the new loss should eventually be backpropped through the entire graph.

You could split the graph by default in the forward pass (for example by detaching the input to the final layer), and then manually pass the gradients of the last part to `backward` of the part that has been cut off whenever that is desired.
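A minimal sketch of that idea, assuming a toy network split into a hypothetical `body` (everything before the final layer) and `head` (the final layer), with a placeholder loss standing in for the policy gradient loss. None of these names come from ACER or from the original code; they are only there to show the detach-then-rejoin pattern:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical split: `body` is the expensive part, `head` is the final layer.
body = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
head = nn.Linear(8, 2)

x = torch.randn(3, 4)

# Forward pass, cut just before the final layer.
hidden = body(x)                               # still attached to body's graph
hidden_cut = hidden.detach().requires_grad_()  # new leaf tensor: the graph is cut here
out = head(hidden_cut)

loss = out.pow(2).sum()  # placeholder loss, not the ACER objective

# Cheap backward: gradients stop at `hidden_cut` and never reach `body`.
loss.backward()
grad_at_cut = hidden_cut.grad  # d(loss)/d(hidden), same shape as `hidden`

# At this point the parameters of `body` have received no gradient.
body_grads_before = [p.grad for p in body.parameters()]

# (This is where the gradient at the cut could be adjusted before rejoining.)

# Rejoin: feed the gradient at the cut into the backward of the cut-off part.
hidden.backward(grad_at_cut)
```

With an unmodified `grad_at_cut`, the gradients accumulated in `body` should match those from a single uncut backward pass, so the split only changes where the gradient can be inspected or altered, not the result.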

Best regards

