How to extract and reassign gradients to save GPU memory?

Hello. I have a large network that is too large to fit in memory at my desired batch size. However, it can be decoupled into two sub-modules, each with its own loss for backprop.
Therefore, I wonder whether there is a way to do something like this:

loss1 = calc_loss(subnet1)
# retrieve the gradient as grad1
loss2 = calc_loss(subnet2)
# retrieve the gradient as grad2
collected_grad = grad1 + grad2
# distribute collected_grad back to the parameters

The key steps are retrieving the gradients and then assigning them back.
Can PyTorch accomplish that?
Thanks !

How are you planning to store the gradients and reassign them?
If you just hold them on the GPU, the same amount of memory will be used.
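For reference, the retrieve-and-reassign mechanics you describe can be sketched with `torch.autograd.grad` plus direct assignment to `p.grad`. This is only a minimal sketch: the two losses below are hypothetical stand-ins for your `calc_loss(subnet1)` / `calc_loss(subnet2)`, and a single shared `Linear` plays the role of the common parameters.

```python
import torch

# A single shared layer stands in for the parameters both sub-modules touch
shared = torch.nn.Linear(4, 1)
params = list(shared.parameters())

x = torch.randn(8, 4)

# Stand-in for calc_loss(subnet1); grad1 is a tuple aligned with `params`
loss1 = shared(x).pow(2).mean()
grad1 = torch.autograd.grad(loss1, params)

# Stand-in for calc_loss(subnet2)
loss2 = shared(x).abs().mean()
grad2 = torch.autograd.grad(loss2, params)

# collected_grad = grad1 + grad2, assigned back to the parameters
for p, g1, g2 in zip(params, grad1, grad2):
    p.grad = g1 + g2
```

Note that simply calling `loss1.backward()` followed by `loss2.backward()` already accumulates the summed gradients into each `p.grad`, with the same peak gradient memory, which is why storing the gradients yourself does not save anything on its own.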

I think you should check out torch.utils.checkpoint, which trades extra compute for memory by recomputing activations during the backward pass instead of storing them.
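A minimal sketch of checkpointing (the `block` module here is a hypothetical stand-in for one memory-heavy segment of your network): intermediate activations inside the checkpointed segment are not kept during the forward pass and are recomputed when backward reaches it.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for one memory-heavy segment of the network
block = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
)

x = torch.randn(2, 16, requires_grad=True)

# Activations inside `block` are discarded after forward and
# recomputed during backward, reducing peak memory
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```

Wrapping each of your two sub-modules this way lets you keep both losses while holding far fewer activations at once.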