How can I temporarily stash gradients to save GPU memory?

Hello. I have a network that is too large to fit in GPU memory at some batch sizes. However, it can be decoupled into two submodules, each with its own loss for backprop.
Therefore, I wonder whether there is a way to do something like this:

subnet1.forward()
loss1 = calc_loss(subnet1)
loss1.backward()
grad1 = copy of the current gradients
optimizer.zero_grad()
subnet2.forward()
loss2 = calc_loss(subnet2)
loss2.backward()
grad2 = copy of the current gradients
optimizer.zero_grad()
collected_grad = grad1 + grad2
assign collected_grad back to the parameters
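The steps above can be sketched in PyTorch roughly as follows. This is a minimal sketch, assuming two independent submodules (the second does not consume the first one's output) that share one optimizer; `subnet1`, `subnet2`, the toy inputs, and the helper `snapshot_grads` are hypothetical stand-ins. Gradients are read from each parameter's `.grad`, copied, summed, and written back before `optimizer.step()`. Note that the copies live on the same device as the gradients, so this by itself does not reduce memory.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two independent submodules and their data.
subnet1 = nn.Linear(8, 4)
subnet2 = nn.Linear(8, 4)
params = list(subnet1.parameters()) + list(subnet2.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)
loss_fn = nn.MSELoss()

x1, y1 = torch.randn(16, 8), torch.randn(16, 4)
x2, y2 = torch.randn(16, 8), torch.randn(16, 4)


def snapshot_grads(parameters):
    """Copy each parameter's .grad (None if it has no gradient yet)."""
    return [None if p.grad is None else p.grad.detach().clone() for p in parameters]


# Step 1: forward/backward through subnet1, then stash its gradients.
optimizer.zero_grad()
loss1 = loss_fn(subnet1(x1), y1)
loss1.backward()
grad1 = snapshot_grads(params)

# Step 2: clear .grad, run subnet2, and stash again.
optimizer.zero_grad()
loss2 = loss_fn(subnet2(x2), y2)
loss2.backward()
grad2 = snapshot_grads(params)
optimizer.zero_grad()

# Step 3: sum the stashed gradients and assign them back to .grad.
for p, g1, g2 in zip(params, grad1, grad2):
    if g1 is None and g2 is None:
        p.grad = None
    elif g1 is None:
        p.grad = g2
    elif g2 is None:
        p.grad = g1
    else:
        p.grad = g1 + g2

optimizer.step()
```

The `None` checks are there because `zero_grad()` may either zero the gradients or set them to `None`, depending on the PyTorch version and its `set_to_none` argument.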

The key step is retrieving the gradients and then assigning them back.
Can PyTorch accomplish that?
Thanks!

How are you planning to store the gradients and assign them back?
If you just hold them on the GPU, they will use the same amount of memory.
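If the goal is only `grad1 + grad2`, the manual stash may be unnecessary: PyTorch adds into each parameter's `.grad` on every `backward()` call, so calling `backward()` on each loss in turn, without `zero_grad()` in between, leaves the sum in `.grad` with no extra gradient copies. After each `backward()` the saved tensors of that graph are released, so only one part's intermediate activations need to be alive at a time. This is a sketch, assuming the two losses come from separate forward passes; the shared-trunk model below is hypothetical and only there to show that shared parameters get the summed gradient.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model: a shared trunk plus two heads, each with its own loss.
model = nn.ModuleDict({
    "trunk": nn.Linear(8, 8),
    "head1": nn.Linear(8, 4),
    "head2": nn.Linear(8, 4),
})
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x1, y1 = torch.randn(16, 8), torch.randn(16, 4)
x2, y2 = torch.randn(16, 8), torch.randn(16, 4)

# Reference: one backward pass on loss1 + loss2, computed on a copy.
ref = copy.deepcopy(model)
ref_loss = (loss_fn(ref["head1"](ref["trunk"](x1)), y1)
            + loss_fn(ref["head2"](ref["trunk"](x2)), y2))
ref_loss.backward()

optimizer.zero_grad()

# Part 1: forward + backward; this graph's saved tensors are freed afterwards.
loss1 = loss_fn(model["head1"](model["trunk"](x1)), y1)
loss1.backward()

# Part 2: no zero_grad() here, so backward() accumulates into .grad.
loss2 = loss_fn(model["head2"](model["trunk"](x2)), y2)
loss2.backward()

# .grad now holds grad1 + grad2, matching the single combined backward.
for p, p_ref in zip(model.parameters(), ref.parameters()):
    assert torch.allclose(p.grad, p_ref.grad, atol=1e-6)

optimizer.step()
```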

I think you should check out torch.utils.checkpoint, which trades compute for memory: instead of storing all intermediate activations for the backward pass, it recomputes them during backward.
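A minimal sketch of `torch.utils.checkpoint.checkpoint`, assuming a PyTorch version recent enough to accept the `use_reentrant` keyword; `stage1` and `stage2` are hypothetical stand-ins for parts of the real network. Each checkpointed segment keeps only its input and re-runs its forward during backward, so the loss and gradients match an ordinary forward/backward while fewer activations are held in memory.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical two-stage network standing in for the real model.
stage1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
stage2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(8, 32)
target = torch.randint(0, 10, (8,))

# Each checkpointed segment stores its input rather than its internal
# activations; those are recomputed when backward() reaches the segment.
h = checkpoint(stage1, x, use_reentrant=False)
out = checkpoint(stage2, h, use_reentrant=False)
loss = loss_fn(out, target)
loss.backward()
```

`use_reentrant=False` is used here because the non-reentrant variant still produces parameter gradients when the segment's input (`x`) does not itself require grad.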