GPU memory usage increases by ~90% after torch.load

I’m currently training a Faster R-CNN model. Normal training consumes ~1900MiB of GPU memory. When I try to resume training from a checkpoint with torch.load, the model takes over 3000MiB, with identical settings specified in a config file.

I believe these are the relevant bits of code:

voc_dataset = PascalVOC(DATA_PATH, transform, LIMIT)
voc_loader = DataLoader(voc_dataset, shuffle=SHUFFLE, pin_memory=True)
basecnn = BaseCNN(ARCHITECTURE, REQUIRES_GRAD)
rpn = RegionProposalNetwork(batch_size=RPN_BATCH_SIZE)
detector = Detector(batch_size=CLF_BATCH_SIZE)

frcnn = FasterRCNN(basecnn, rpn, detector)
optimizer = Adam(filter(lambda p: p.requires_grad, frcnn.parameters()), lr=LEARNING_RATE)

if RESUME_PATH:
    experiment = torch.load(RESUME_PATH)  # no map_location: tensors are restored to the device they were saved from
    frcnn.load_state_dict(experiment['model_state'])
    optimizer.load_state_dict(experiment['optimizer_state'])
    frcnn.basecnn.finetune()

frcnn.cuda()

Full training code is here: https://github.com/A-Jacobson/faster_rcnn/blob/master/train.py
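The MiB figures above presumably come from nvidia-smi, which reports memory reserved by PyTorch's caching allocator plus the CUDA context, not what live tensors actually use. A minimal sketch for narrowing down where the extra memory appears, assuming PyTorch 1.4 or later (where torch.cuda.memory_reserved exists). gpu_memory_report is a hypothetical helper name, and it returns zeros on a machine without CUDA so it can run anywhere:

```python
import torch


def gpu_memory_report(tag: str, device=None) -> dict:
    """Hypothetical helper: return (and print) CUDA allocator statistics in MiB.

    allocated: memory held by live tensors.
    reserved:  memory held by the caching allocator (closer to nvidia-smi,
               which additionally includes the CUDA context).
    peak:      highest 'allocated' value seen so far.
    Returns zeros when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        stats = {"allocated": 0.0, "reserved": 0.0, "peak": 0.0}
    else:
        mib = 1024 ** 2
        stats = {
            "allocated": torch.cuda.memory_allocated(device) / mib,
            "reserved": torch.cuda.memory_reserved(device) / mib,
            "peak": torch.cuda.max_memory_allocated(device) / mib,
        }
    print(f"[{tag}] allocated={stats['allocated']:.0f}MiB "
          f"reserved={stats['reserved']:.0f}MiB peak={stats['peak']:.0f}MiB")
    return stats
```

One could call it right after torch.load, after each load_state_dict, and after the first backward pass, once for a fresh run and once for a resumed run, and compare the two sets of numbers.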


Yes, I have experienced the same situation before. This is quite annoying! I run into OOM whenever I try to resume training from a checkpoint.

Have you tried saving and loading the state_dict? I’ve seen some issues, e.g. when training was resumed using the second approach explained here.

Yes. I always follow the best practice of saving and loading the state_dict.
I found a related issue here.
It says torch.cuda.empty_cache() might help, but in my case I still get OOM.
By the way, I’m using PyTorch 0.3.1.

Good news! I seem to be able to avoid OOM with the following loading strategy:

checkpoint = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
net.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])  # PyTorch 0.3.1 has a bug here; it's fixed in master
del checkpoint  # dereference seems crucial
torch.cuda.empty_cache()
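The same strategy, wrapped in a function as a sketch. It assumes the checkpoint keys 'model' and 'optimizer' from the snippet above, uses the string form map_location="cpu" (accepted by recent PyTorch versions and equivalent to the lambda), and moves the model to its device before loading the optimizer state, because Optimizer.load_state_dict casts state tensors to the device of the matching parameter. resume_from_checkpoint is a hypothetical name:

```python
import torch
from torch import nn


def resume_from_checkpoint(f, model: nn.Module, optimizer, device):
    """Sketch: restore model/optimizer state without keeping a GPU copy of the checkpoint.

    f may be a path or a file-like object. Returns the remaining
    (non-model, non-optimizer) entries, e.g. an epoch counter.
    """
    # Deserialize every tensor onto the CPU, whatever device it was saved from.
    checkpoint = torch.load(f, map_location="cpu")
    # Move the parameters first, so the optimizer state is cast to the right device.
    model.to(device)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    extra = {k: v for k, v in checkpoint.items() if k not in ("model", "optimizer")}
    # Drop the last reference to the loaded dict so its tensors can be freed.
    del checkpoint
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release unused cached blocks back to the driver
    return extra
```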

That’s good!
Btw, shouldn’t the checkpoint be in CPU memory when loading with the map_location parameter?

Yes. With map_location=lambda storage, loc: storage, the tensors in checkpoint are initially in CPU memory.
However, I guess load_state_dict may internally cast tensors to the device of the corresponding model parameters, and checkpoint still holds references to the cast tensors. I haven’t really traced it down.
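One way to probe part of this guess without a GPU: by default (assign=False in recent versions), Module.load_state_dict copies values into the model's existing parameters with param.copy_(), so the checkpoint keeps its own separate copy of every weight until it is dereferenced. If that copy was deserialized onto the GPU (no map_location), it occupies GPU memory until then. This sketch only checks nn.Module, not the optimizer, and the toy nn.Linear is an assumption for illustration:

```python
import torch
from torch import nn

# Toy model plus a checkpoint holding independent (cloned) weights.
model = nn.Linear(4, 2)
checkpoint = {"model": {k: v.clone() for k, v in nn.Linear(4, 2).state_dict().items()}}

model.load_state_dict(checkpoint["model"])

ckpt_w = checkpoint["model"]["weight"]
# Different storage: the values were copied into the existing parameter.
shares_storage = ckpt_w.data_ptr() == model.weight.data_ptr()
# Zeroing the checkpoint tensor afterwards leaves the model's weight intact.
ckpt_w.zero_()
model_unchanged = bool(model.weight.abs().sum() > 0)
print(shares_storage, model_unchanged)
```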

I’ll have to check this on my end. Embarrassingly, I ended up getting a new GPU, and the temporary spike didn’t matter all that much anymore.

Hi roytseng, I face the same problem with 0.3.1, and it is still blocking me. I have tried del + torch.cuda.empty_cache(), but it doesn’t work in my case. I noticed your comment “pytorch 0.3.1 has a bug on this, it’s fix in master”; could you explain more? Links to the original issues or commits would be really helpful (I have checked the commits about optimizers but could not find it). Thanks in advance!

Hi zjoe,

I don’t really remember whether the bug in the optimizer’s load_state_dict is related to the increased memory usage (I guess it’s not). However, I’m sure this bug has been fixed in PyTorch 0.4.

Hi all,
Have you been able to fix the problem?
I am experiencing the same problem with PyTorch 0.5, so newer versions do not seem to solve it. I am using:

if load:
    checkpoint = torch.load('./model.ckpt')
    startEpoch = checkpoint['StartEpoch']
    model.load_state_dict(checkpoint['state_dict'])
    del checkpoint
    torch.cuda.empty_cache()

but the problem persists. The problem shows up at

loss.backward()

even though the model is the same and I am not loading anything other than the model and the epoch number. I have to reduce the batch size, which is very annoying.
Thanks,
Dani.
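A complementary approach on the saving side, not taken from this thread: torch.load restores each tensor to the device it was saved from, so if every tensor is moved to the CPU before torch.save, even a plain torch.load like the one above cannot allocate GPU memory for the checkpoint. to_cpu and save_checkpoint are hypothetical helpers; the 'StartEpoch' and 'state_dict' keys mirror the loading code above, and the 'optimizer' key is an assumption:

```python
from collections import OrderedDict

import torch
from torch import nn


def to_cpu(obj):
    """Recursively copy every tensor in a (nested) state dict to the CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        # OrderedDict so that the _metadata attribute can be carried over.
        out = OrderedDict((k, to_cpu(v)) for k, v in obj.items())
        if hasattr(obj, "_metadata"):
            out._metadata = obj._metadata  # version info read by load_state_dict
        return out
    if isinstance(obj, list):
        return [to_cpu(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(to_cpu(v) for v in obj)
    return obj


def save_checkpoint(f, model: nn.Module, start_epoch: int, optimizer=None):
    """Save a checkpoint whose tensors all live on the CPU."""
    checkpoint = {"StartEpoch": start_epoch, "state_dict": to_cpu(model.state_dict())}
    if optimizer is not None:
        checkpoint["optimizer"] = to_cpu(optimizer.state_dict())  # hypothetical key
    torch.save(checkpoint, f)
```

Storing optimizer state on the CPU should be harmless on resume, because Optimizer.load_state_dict casts state tensors such as momentum buffers to the device of their parameter.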


Have you solved it? I’m running into this too.

del checkpoint  # dereference seems crucial

Worked for me!
It seems to have saved around 500 MB in my case.
I was able to save but not load in PyTorch. The checkpoint definitely took up valuable GPU memory.

torch.cuda.empty_cache() 

Helped clear around 300 MB!

Thank you @roytseng !

This thread is really old, but I have to report that this can still be an issue with PyTorch 1.8.1. For example, before using this method I had to lower my model’s batch size from 64 to 40 just to be able to resume from a checkpoint. That’s quite a performance hit. This method works beautifully and is still relevant, cheers.

Edit: I can also confirm that map_location='cpu' also works.
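map_location accepts several equivalent forms; this thread uses the lambda and the 'cpu' string. A minimal sketch that round-trips a toy checkpoint through an in-memory buffer (io.BytesIO standing in for a checkpoint file), so it runs without a GPU:

```python
import io

import torch

buf = io.BytesIO()
torch.save({"w": torch.arange(3.0)}, buf)

# Ways to ask torch.load to put every tensor on the CPU.
map_locations = [
    "cpu",                         # device string
    torch.device("cpu"),           # torch.device object
    lambda storage, loc: storage,  # callable: keep the CPU-deserialized storage as is
    {"cuda:0": "cpu"},             # dict: remap a saved location tag to another device
]

devices = []
for ml in map_locations:
    buf.seek(0)  # rewind the in-memory file before each load
    ckpt = torch.load(buf, map_location=ml)
    devices.append(ckpt["w"].device.type)
print(devices)
```

Note that the dict form only remaps the locations it lists: with the mapping above, a tensor saved from cuda:1 would still be restored to cuda:1.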


Still an issue. Ran into this today on PyTorch 1.10.

Deleting checkpoint and emptying cache doesn’t work.

Loading the model to the CPU first works, as pointed out by @realiti.