CUDA OOM when resuming training

Hi all,

This has been giving me quite a headache.

I am training on 32 GPUs for 40 epochs, but the run stopped at the 30th epoch.

I want to resume training from the checkpoint. Although the checkpoint loads successfully, the training script runs out of CUDA memory on the very first minibatch.

Here are my saving and loading functions:

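from pathlib import Path

import torch

# save_on_master is a helper defined elsewhere (assumed: calls torch.save on rank 0 only).
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            save_on_master(to_save, checkpoint_path)

def load_model(args, model_without_ddp, optimizer=None, loss_scaler=None):
    if args.resume:
        # Load onto CPU so the ranks do not all materialize the checkpoint on one GPU.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint['model'])
        if 'optimizer' in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        # Also restore the AMP loss scaler, which is saved but was never loaded back.
        if 'scaler' in checkpoint and loss_scaler is not None:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        # Resume from the next epoch (assumes the training loop reads args.start_epoch).
        if 'epoch' in checkpoint:
            args.start_epoch = checkpoint['epoch'] + 1
        # Drop the CPU copy of the checkpoint once everything has been loaded.
        del checkpoint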
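For reference, here is a minimal, self-contained CPU round trip of the same save/load pattern. It is a sketch, not the original script: `torch.optim.AdamW` stands in for Lamb (both keep `exp_avg` and `exp_avg_sq` per parameter), and `save_checkpoint`/`load_checkpoint` are hypothetical names. One detail worth knowing: `map_location="cpu"` only controls where `torch.load` puts the tensors. `optimizer.load_state_dict()` then copies each state tensor to the device of its parameter, so on a GPU run the moment buffers still end up in GPU memory.

```python
import io

import torch
from torch import nn


def save_checkpoint(buffer, model, optimizer, epoch):
    # Same keys as save_model above, minus 'scaler' and 'args'.
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch},
        buffer,
    )


def load_checkpoint(buffer, model, optimizer):
    # Load to CPU first; optimizer.load_state_dict() moves each state tensor
    # to its parameter's device (a no-op here because everything is on CPU).
    checkpoint = torch.load(buffer, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    start_epoch = checkpoint["epoch"] + 1
    # Release the loaded dict and, on GPU runs, return cached blocks to the driver.
    del checkpoint
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return start_epoch


torch.manual_seed(0)
model = nn.Linear(4, 2)
# AdamW as a stand-in for Lamb (hypothetical substitution for this demo).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()  # creates exp_avg / exp_avg_sq

buffer = io.BytesIO()
save_checkpoint(buffer, model, optimizer, epoch=29)
buffer.seek(0)

resumed_model = nn.Linear(4, 2)
resumed_opt = torch.optim.AdamW(resumed_model.parameters(), lr=1e-3)
start_epoch = load_checkpoint(buffer, resumed_model, resumed_opt)
print("resuming at epoch", start_epoch)
```

The explicit `del` matters less inside a function, where the local is freed on return anyway. It is shown because the same pattern is often written at module level, where the checkpoint dict would otherwise stay alive for the whole run.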

Here is a minimal example. The optimizer is Lamb from cybertronai/pytorch-lamb.

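import torch
import timm.optim.optim_factory as optim_factory
from pytorch_lamb import Lamb

# MyModelClass, args and device come from the rest of the training script.
model = MyModelClass()
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
model_without_ddp = model.module

# Split parameters into decay / no-decay groups, then build the Lamb optimizer.
param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
optimizer = Lamb(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-4)

load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer)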
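One way to narrow this down is to log the CUDA caching allocator's statistics right before and after `load_model`, and again just before the first backward, then compare the numbers with a from-scratch run. The helper below is a hypothetical sketch built only on `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved` and `torch.cuda.max_memory_allocated`; it returns `None` on machines without CUDA.

```python
import torch


def report_cuda_memory(tag, device=None):
    """Print and return allocator stats in MiB, or None if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    mib = 1024 ** 2
    stats = {
        "allocated": torch.cuda.memory_allocated(device) / mib,  # live tensors
        "reserved": torch.cuda.memory_reserved(device) / mib,  # held by the cache
        "peak_allocated": torch.cuda.max_memory_allocated(device) / mib,
    }
    print(f"[{tag}] " + ", ".join(f"{k}={v:.0f} MiB" for k, v in stats.items()))
    return stats
```

For example, call `report_cuda_memory("after load")` directly after `load_model(...)` and `report_cuda_memory("before backward")` just before `loss.backward()`. A large gap between `reserved` and `allocated` would point to cached or fragmented memory rather than live tensors.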

I've observed that training from scratch works fine, but after loading the checkpoint, training fails at the first backward pass.
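A possibly relevant difference: when training from scratch, Adam-style optimizers such as Lamb create `exp_avg` and `exp_avg_sq` lazily inside the first `optimizer.step()`, which runs after the first backward has finished. After resuming, those buffers are already on the GPU before the first forward pass, so the very first backward runs with the full steady-state footprint. The back-of-the-envelope estimate below is a sketch under stated assumptions (fp32 weights, gradients and two moment buffers; activations and framework overhead ignored), and `training_state_bytes` is a hypothetical helper.

```python
def training_state_bytes(num_params, bytes_per_value=4, moments=2):
    """Rough persistent memory per GPU, excluding activations.

    Counts weights + gradients + optimizer moment buffers
    (exp_avg and exp_avg_sq for Adam-style optimizers such as Lamb).
    """
    copies = 1 + 1 + moments  # weights, grads, optimizer moments
    return num_params * bytes_per_value * copies


b = training_state_bytes(300_000_000)  # e.g. a 300M-parameter model
print(f"{b / 2**30:.2f} GiB")  # 4.47 GiB
```

DistributedDataParallel's gradient buckets can add roughly one more gradient-sized copy unless `gradient_as_bucket_view=True` is set, so the real number per GPU is likely somewhat higher than this estimate.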

I found a related thread (https://discuss.pytorch.org/t/gpu-memory-usage-increases-by-90-after-torch-load/9213/14), but its suggested fix doesn't work for me.

Any help would be greatly appreciated.