Hello. I have a large network that is too large to run at the batch size I need. However, it can be decoupled into two submodules, each with its own loss for backpropagation.
Therefore, I wonder whether there is a way to do something like this:
subnet1.forward()
loss1 = calc_loss(subnet1)
loss1.backward()
# get the gradients into grad1
optimizer.zero_grad()

subnet2.forward()
loss2 = calc_loss(subnet2)
loss2.backward()
# get the gradients into grad2
optimizer.zero_grad()

collected_grad = grad1 + grad2
# distribute collected_grad back to the parameters
The key step is to retrieve the gradients and then assign them back to the parameters.
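For clarity, here is a minimal sketch of what I have in mind. The names are placeholders: a toy torch.nn.Linear stands in for the shared parameters, and the two made-up loss functions (calc_loss1, calc_loss2) stand in for the real submodule losses.

```python
import torch

# hypothetical setup: a small model standing in for the shared parameters
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def calc_loss1(m, x):          # stand-in for subnet1's loss
    return m(x).pow(2).mean()

def calc_loss2(m, x):          # stand-in for subnet2's loss
    return m(x).abs().mean()

x1 = torch.randn(4, 10)
x2 = torch.randn(4, 10)

# pass 1: backprop the first loss and copy the gradients out
optimizer.zero_grad()
loss1 = calc_loss1(model, x1)
loss1.backward()
grad1 = {name: p.grad.detach().clone() for name, p in model.named_parameters()}

# pass 2: backprop the second loss and copy the gradients out
optimizer.zero_grad()
loss2 = calc_loss2(model, x2)
loss2.backward()
grad2 = {name: p.grad.detach().clone() for name, p in model.named_parameters()}

# combine the two gradients and assign them back to the parameters
optimizer.zero_grad()
for name, p in model.named_parameters():
    p.grad = grad1[name] + grad2[name]

optimizer.step()               # the update now uses the combined gradients
```

In this sketch the gradients are copied out with detach().clone() so the later zero_grad() calls do not wipe them before they are summed and written back to p.grad.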
Can PyTorch accomplish that?