I am training a rather large model with PyTorch (estimated forward/backward pass around 10 GB), and the training sometimes crashes at around epoch ~100 of 800 (the computer shuts down). I am aware there can be multiple causes, like the GPU overheating, the CPU being thrashed, etc. So I was wondering which metrics would be best for tracking down the issue, and, if possible, how to track these metrics efficiently, e.g. which libraries/commands to use.
For instance, tracking CPU usage is rather hard, since the CPU is mostly busy while the dataloader is loading images, so I am not sure where in the training loop to call the measurement. Looking forward to your ideas!
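For context, something like this background sampler is what I had in mind so far: a separate thread logs metrics at a fixed interval, so it also catches CPU spikes that happen inside the dataloader workers rather than at the point in the loop where I call it. The actual metric readers (e.g. `psutil.cpu_percent` or a pynvml temperature query) are just my guesses at suitable libraries; not sure if this is the right approach?

```python
import csv
import threading
import time


class MetricLogger:
    """Samples a set of metric callables every `interval` seconds on a
    background thread and appends them to a CSV file, independently of
    where the training loop currently is (forward pass, data loading, ...)."""

    def __init__(self, path, metric_fns, interval=5.0):
        self.path = path
        self.metric_fns = metric_fns  # dict: column name -> zero-arg callable
        self.interval = interval
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        with open(self.path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["time"] + list(self.metric_fns))
            while not self._stop.is_set():
                writer.writerow(
                    [time.time()] + [fn() for fn in self.metric_fns.values()]
                )
                f.flush()  # flush each row so the log survives a hard shutdown
                self._stop.wait(self.interval)

    def start(self):
        self._thread.start()

    def stop(self):
        self._stop.set()
        self._thread.join()
```

In the training script I would then plug in the real readers, e.g. (assuming `psutil` and `pynvml` are installed) `psutil.cpu_percent` and a small wrapper around `pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)`, start the logger before the epoch loop, and stop it afterwards. Flushing every row matters here, because on a hard shutdown anything still buffered is lost.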