Hi guys,
I just read some materials on the Internet saying that checkpointing in PyTorch can be used to decrease GPU memory usage, and I want to know if that is true?
Any answer or idea will be appreciated!
Hi,
This is what this module was built for. So yes, it reduces memory usage at the cost of extra compute.
Hi, Alban,
Could you please share a tutorial on how to use checkpoint to reduce memory usage?
I don’t think there is much more than just the doc: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
It should be as simple as wrapping groups of layers from your network with the checkpoint function.
You can find an advanced example in torchvision: https://github.com/pytorch/vision/blob/216035315185edec747dca8879d7197e7fb22c7d/torchvision/models/densenet.py
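To illustrate what "wrapping groups of layers" might look like, here is a minimal sketch. The model, layer sizes, and the `use_reentrant=False` flag are illustrative assumptions (the flag requires a reasonably recent PyTorch), not something from this thread:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical network whose middle block is checkpointed: the
# activations inside `block` are not stored during the forward pass
# and are recomputed during the backward pass.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Linear(64, 128)
        self.block = nn.Sequential(
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
        )
        self.head = nn.Linear(128, 10)

    def forward(self, x):
        x = self.stem(x)
        # Checkpoint the whole group of layers as one unit.
        x = checkpoint(self.block, x, use_reentrant=False)
        return self.head(x)

net = Net()
x = torch.randn(8, 64, requires_grad=True)
loss = net(x).sum()
loss.backward()  # gradients still flow through the checkpointed block
print(x.grad.shape)  # torch.Size([8, 64])
```

Note that the input to the checkpointed segment must require grad (or the segment must contain parameters) for gradients to propagate through it.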
I am really grateful for your answer.
After reading some of the materials you provided, it seems that using checkpoint to trade compute for memory only applies to ReLU functions. Am I right?
Looking forward to your reply!
Hi,
No, it applies to any module.
It lets you avoid storing the intermediate results between ops, so it will only be useful if you checkpoint a function that contains multiple ops.
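A small sketch of that point: checkpointing a function made of several ops drops the intermediates between them, and the backward result is unchanged. The function `multi_op` and the tensor shapes here are made-up examples:

```python
import torch
from torch.utils.checkpoint import checkpoint

# A function with several intermediate results: without checkpointing,
# autograd keeps the outputs of sin, exp, and the product for backward.
def multi_op(x):
    return (torch.sin(x) * torch.exp(x)).sum(dim=-1)

x = torch.randn(4, 16, requires_grad=True)

# Checkpointing the whole function frees those intermediates after the
# forward pass and recomputes them during backward.
y = checkpoint(multi_op, x, use_reentrant=False)
y.sum().backward()

# Same gradients as the un-checkpointed version.
x2 = x.detach().requires_grad_()
multi_op(x2).sum().backward()
print(torch.allclose(x.grad, x2.grad))  # True
```

Checkpointing a single op would save nothing, since there are no intermediates between its input and output to discard.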
Thanks sincerely for your answer!