Save Memory for Multiple Forward Passes?


I am trying to train a network that runs on three different images and thus produces three outputs. The outputs are later combined in some way to compute the loss.
The issue is that the model is quite large, so running it three times might cause GPU OOM errors.
So far I have been wrapping the forward passes for two of the images in torch.no_grad() and letting gradients backpropagate only through the last image.
Is there a better way to do this?
I was told it's better to let gradients flow through all three images.
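The workaround described above could be sketched as follows (the model and tensor shapes are placeholders, not the actual network):

```python
import torch
import torch.nn as nn

# Placeholder standing in for the large network
model = nn.Linear(8, 4)

img1, img2, img3 = (torch.randn(2, 8) for _ in range(3))

# Run the first two images without building an autograd graph;
# only the third forward pass tracks gradients.
with torch.no_grad():
    out1 = model(img1)
    out2 = model(img2)
out3 = model(img3)

# Combine the three outputs into a loss (combination is arbitrary here)
loss = (out1 + out2 + out3).pow(2).mean()
loss.backward()  # gradients flow only through the out3 path
```

This saves the activation memory of two forward passes, but the parameters only ever receive gradient signal from one of the three images.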

If you have serious memory issues, you can use the torch.utils.checkpoint module to reduce memory usage (at the cost of extra compute).
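A minimal sketch of that suggestion, with a placeholder model: wrapping each forward pass in `torch.utils.checkpoint.checkpoint` discards the intermediate activations during the forward pass and recomputes them during backward, so gradients still flow through all three images while only one pass's activations need to be live at a time.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder for the large network
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# Inputs must require grad so the (default) checkpoint implementation
# builds a graph and recomputes activations on backward.
imgs = [torch.randn(2, 8, requires_grad=True) for _ in range(3)]

# Each checkpointed pass stores only its inputs and outputs;
# intermediate activations are recomputed during backward.
outs = [checkpoint(model, img) for img in imgs]

# Combine the three outputs into a loss (combination is arbitrary here)
loss = torch.stack(outs).sum()
loss.backward()  # gradients reach the parameters from all three images
```

In practice you would checkpoint the expensive segments of the network rather than necessarily the whole model, depending on where the activation memory goes.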

Thank you.
Another question, is the issue mentioned in this repo fixed in 1.3?

Essentially, the checkpoint function would be slower on multiple GPUs when using nn.DataParallel.

I am not familiar with that repo, so you might want to double-check. I am not sure what the difference between the implementations is.