Hi, I got an out of memory error with the code:
model.to(device).train()
I was running with PyTorch 1.8.0 on a Ubuntu 20.04 machine, and used 2 Nvidia GTX 2080ti GPU for training. The CUDA version is 11.0. However, this error occured before DataParallel. The full traceback message is:
Traceback (most recent call last):
File "mains_single.py", line 547, in <module>
main()
File "mains_single.py", line 143, in main
train(args)
File "mains_single.py", line 203, in train
save_checkpoint(args, net, 1234, 1e6)
File "mains_single.py", line 542, in save_checkpoint
model.to(device).train()
File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 673, in to return self._apply(convert)
File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 387, in _apply
module._apply(fn)
File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 387, in _apply
module._apply(fn)
File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 409, in _apply
param_applied = fn(param)
File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 671, in convert
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA error: out of memory
Do anyone know a possible reason? The context of the code is:
save_checkpoint(args, net, 1234, 1e6)
def save_checkpoint(args, model, epoch, eval_loss):
device = torch.device("cuda" if args.cuda else "cpu")
model.eval().cpu()
checkpoint_model_dir = args.ckp_dir
if not os.path.exists(checkpoint_model_dir):
os.makedirs(checkpoint_model_dir)
ckpt_model_filename = args.dataset_name + "_" + args.model_title + "_ckpt_epoch_" + str(epoch) + ".pth"
ckpt_model_path = os.path.join(checkpoint_model_dir, ckpt_model_filename)
state = {"epoch": epoch, "model": model, "eval_loss": eval_loss}
torch.save(state, ckpt_model_path)
model.to(device).train()
print("Checkpoint saved to {}".format(ckpt_model_path))