I am trying to train a network that is run on three different images, producing three outputs. These outputs are then combined in some way to compute the loss.
The issue is that the model is quite large, so running it three times can cause GPU OOM errors.
I have been wrapping the forward passes for two of the images in torch.no_grad(), so gradients only backpropagate through the last image.
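The setup looks roughly like this. This is a minimal sketch, with a tiny stand-in model and randomly generated tensors in place of the real network and images, showing two passes under torch.no_grad() and one pass with gradients enabled:

```python
import torch
import torch.nn as nn

# Hypothetical small model standing in for the large network.
model = nn.Linear(8, 4)

# Placeholder tensors standing in for the three images.
img1 = torch.randn(2, 8)
img2 = torch.randn(2, 8)
img3 = torch.randn(2, 8)

# Run two of the images without tracking gradients, so no
# activations are stored for backprop (saving GPU memory)...
with torch.no_grad():
    out1 = model(img1)
    out2 = model(img2)

# ...and let gradients flow only through the third image.
out3 = model(img3)

# Combine the three outputs in some way (here just a sum) to form the loss.
loss = (out1 + out2 + out3).pow(2).mean()
loss.backward()  # gradients come only from the img3 pass
```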
Is there a better way to do so?
I was told it is better to let gradients flow through all three images.