How to extract and reassign gradients to save GPU memory?

Hello. I have a large network that is too large to fit in memory at my desired batch size. However, it can be decoupled into two sub-modules, each with its own loss for backprop.
Therefore, I wonder whether there is a way to do something like this:

loss1 = calc_loss(subnet1)
# retrieve the gradient as grad1
loss2 = calc_loss(subnet2)
# retrieve the gradient as grad2
collected_grad = grad1 + grad2
# distribute collected_grad back to the parameters

The key steps are retrieving the gradients and then assigning them back.
Can PyTorch accomplish that?
Thanks !

How are you planning to store the gradients and reassign them?
If you just hold them on the GPU, the same amount of memory will be used.
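For reference, the retrieve-and-reassign mechanics you describe can be sketched with `torch.autograd.grad` plus direct assignment to `p.grad`. This is only a minimal sketch: the two losses below are hypothetical stand-ins for your `calc_loss(subnet1)` / `calc_loss(subnet2)`, and a single shared `Linear` plays the role of the common parameters.

```python
import torch

# A single shared layer stands in for the parameters both sub-modules touch
shared = torch.nn.Linear(4, 1)
params = list(shared.parameters())

x = torch.randn(8, 4)

# Stand-in for calc_loss(subnet1); grad1 is a tuple aligned with `params`
loss1 = shared(x).pow(2).mean()
grad1 = torch.autograd.grad(loss1, params)

# Stand-in for calc_loss(subnet2)
loss2 = shared(x).abs().mean()
grad2 = torch.autograd.grad(loss2, params)

# collected_grad = grad1 + grad2, assigned back to the parameters
for p, g1, g2 in zip(params, grad1, grad2):
    p.grad = g1 + g2
```

Note that simply calling `loss1.backward()` followed by `loss2.backward()` already accumulates the summed gradients into each `p.grad`, with the same peak gradient memory, which is why storing the gradients yourself does not save anything on its own.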

I think you should check out torch.utils.checkpoint, which trades extra compute for memory by recomputing activations during the backward pass instead of storing them.
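A minimal sketch of checkpointing (the `block` module here is a hypothetical stand-in for one memory-heavy segment of your network): intermediate activations inside the checkpointed segment are not kept during the forward pass and are recomputed when backward reaches it.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for one memory-heavy segment of the network
block = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
)

x = torch.randn(2, 16, requires_grad=True)

# Activations inside `block` are discarded after forward and
# recomputed during backward, reducing peak memory
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```

Wrapping each of your two sub-modules this way lets you keep both losses while holding far fewer activations at once.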