Is it possible to do forward() and backward() more memory efficiently? For example, during the chain rule, we only move the part of the computation graph to the CUDA device that needs to be backpropagated. I think the only overhead here is .to(device).
May I know what you mean by memory efficient?
As far as I know, PyTorch only stores the intermediate outputs needed for the weights whose gradients are required (requires_grad=True). In a way, this already seems to do what you are asking for.
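A minimal sketch of what I mean (assuming a recent PyTorch): intermediates are only kept around when some input requires a gradient, and nothing is saved under no_grad.

```python
import torch

x = torch.randn(3, requires_grad=True)
z = x.exp()    # exp() saves its output z for the backward pass
y = z.sum()
y.backward()   # gradient of sum(exp(x)) w.r.t. x is exp(x) itself
assert x.grad is not None

with torch.no_grad():
    w = x.exp()        # nothing is saved here: w has no grad_fn
assert w.grad_fn is None
```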
If I misunderstood your question, could you explain with a small code example, if possible?
What I mean is to move some of those intermediate outputs to CPU RAM and only bring them back when backpropagation reaches them.
For example, z = f(x), y = g(z): after I get z, I move the intermediate outputs of f(x) to CPU RAM and then start computing g(z). The same idea applies to backpropagation, but in reverse.
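If I understand correctly, PyTorch has a built-in hook for roughly this idea: torch.autograd.graph.save_on_cpu moves every tensor saved for backward to CPU during the forward pass and copies it back on demand during backward. A small sketch (the model and shapes are just placeholders):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(4, 8)

# Tensors saved for backward are offloaded to CPU RAM during forward
# and fetched back to the compute device when backward reaches them.
# (pin_memory=True can speed up the copies when CUDA is available.)
with torch.autograd.graph.save_on_cpu(pin_memory=False):
    loss = model(x).sum()
loss.backward()
```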
This would not work well, since it would require the CPU and GPU to synchronize quite often (currently they run asynchronously), and synchronization slows things down a lot. The cost of calling .to() for every intermediate output would probably outweigh the memory savings.
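For what it's worth, the more common way to trade compute for memory in PyTorch is activation checkpointing: instead of offloading intermediates, they are simply not stored and are recomputed during backward. A minimal sketch (f here is just an illustrative function):

```python
import torch
from torch.utils.checkpoint import checkpoint

def f(x):
    # intermediates inside f are not stored during forward;
    # f is re-run during backward to recreate them
    return torch.relu(x) ** 2

x = torch.randn(4, 8, requires_grad=True)
y = checkpoint(f, x, use_reentrant=False).sum()
y.backward()
```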