Loading a model multiple times

I am using the following pattern in my training procedure:
output1 = model(input, mode='a')
loss1 = loss_fun1(output1)
output2 = model(input, mode='b')
loss2 = loss_fun2(output2)
total_loss = loss1 + loss2

The problem is that I get a memory error when I call the model twice. If I wrap the mode 'b' call in torch.no_grad(), I break the computational graph, which I cannot do. Is there a solution to this problem?

Your current approach keeps the intermediate forward activations of both forward passes in memory, since they are needed to compute the gradients in the backward pass. You could try to lower the batch size, if possible, or trade compute for memory via torch.utils.checkpoint, which recomputes the activations during the backward pass instead of storing them.
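As a rough sketch of the checkpointing suggestion, assuming a recent PyTorch version that accepts the use_reentrant argument: TwoModeNet, the layer sizes, and the two loss expressions below are hypothetical stand-ins for your model and loss functions, not your actual code. Each forward pass is wrapped in checkpoint, so its intermediate activations should be recomputed during backward rather than kept alive for both passes at once.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TwoModeNet(nn.Module):
    """Hypothetical stand-in for a model whose forward takes a `mode` flag."""

    def __init__(self, dim=16, hidden=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.head_a = nn.Linear(hidden, 1)
        self.head_b = nn.Linear(hidden, 1)

    def forward(self, x, mode="a"):
        h = self.body(x)
        return self.head_a(h) if mode == "a" else self.head_b(h)


torch.manual_seed(0)
model = TwoModeNet()
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(8, 16)  # dummy batch

optimizer.zero_grad()

# Wrap each call in checkpoint: only the inputs are saved, and the
# forward pass is re-run during backward to rebuild the activations.
output1 = checkpoint(model, x, "a", use_reentrant=False)
loss1 = output1.pow(2).mean()  # placeholder for loss_fun1

output2 = checkpoint(model, x, "b", use_reentrant=False)
loss2 = output2.abs().mean()   # placeholder for loss_fun2

# The graph is intact, so gradients from both losses reach the parameters.
total_loss = loss1 + loss2
total_loss.backward()
optimizer.step()
```

The cost is extra compute, since each checkpointed forward pass runs a second time during backward. In a larger model you would probably checkpoint individual blocks rather than the whole model, which gives finer control over the memory/compute trade-off.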