CUDA out of memory error from the .to() operation

Hi, I got an out-of-memory error from this line of code:

model.to(device).train()

I was running PyTorch 1.8.0 on an Ubuntu 20.04 machine and using 2 Nvidia RTX 2080 Ti GPUs for training. The CUDA version is 11.0. However, this error occurred before the model was wrapped in DataParallel. The full traceback is:

Traceback (most recent call last):
  File "mains_single.py", line 547, in <module>
    main()
  File "mains_single.py", line 143, in main
    train(args)
  File "mains_single.py", line 203, in train
    save_checkpoint(args, net, 1234, 1e6)
  File "mains_single.py", line 542, in save_checkpoint
    model.to(device).train()
  File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 673, in to    return self._apply(convert)
  File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 387, in _apply
    module._apply(fn)
  File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 387, in _apply
    module._apply(fn)
  File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 409, in _apply
    param_applied = fn(param)
  File "/data_c/lzq0330/miniconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 671, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA error: out of memory

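In case it's useful, this is roughly how I plan to log per-GPU memory right before the failing call on the next run (just a sketch; log_gpu_memory is a helper name I made up, not something in my current script):

import torch

def log_gpu_memory(tag):
    # Print allocated/reserved memory for every visible GPU.
    for i in range(torch.cuda.device_count()):
        alloc = torch.cuda.memory_allocated(i) / 1024 ** 2
        reserved = torch.cuda.memory_reserved(i) / 1024 ** 2
        print("{}: cuda:{} allocated {:.1f} MiB, reserved {:.1f} MiB".format(tag, i, alloc, reserved))

log_gpu_memory("before model.to(device)")
model.to(device).train()
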
Does anyone know a possible reason? The surrounding code is:

# call site inside train():
save_checkpoint(args, net, 1234, 1e6)

def save_checkpoint(args, model, epoch, eval_loss):
    device = torch.device("cuda" if args.cuda else "cpu")
    model.eval().cpu()  # move the model to CPU before saving
    checkpoint_model_dir = args.ckp_dir
    if not os.path.exists(checkpoint_model_dir):
        os.makedirs(checkpoint_model_dir)
    ckpt_model_filename = args.dataset_name + "_" + args.model_title + "_ckpt_epoch_" + str(epoch) + ".pth"
    ckpt_model_path = os.path.join(checkpoint_model_dir, ckpt_model_filename)
    state = {"epoch": epoch, "model": model, "eval_loss": eval_loss}
    torch.save(state, ckpt_model_path)
    model.to(device).train()  # this is the line that raises the CUDA OOM error
    print("Checkpoint saved to {}".format(ckpt_model_path))