OOM when resuming training

Hello there, I saved a model state dict along with the optimizer state, but when I try to resume training, the model with its loaded state seems to consume more memory (about 2 GiB more) than the original run did, which looks strange. So I just want to ask whether it could be anything other than very bad loading code!
Basically, the end of my original training pipeline looks like this:

model = nn.Sequential(
    nn.Conv3d(1, 16, 3, 1, 1),
    nn.BatchNorm3d(16),
    nn.Conv3d(16, 32, 3, 1, 1),
    nn.Flatten(start_dim=1),
    nn.LazyLinear(128), nn.GELU(),    
    nn.LazyLinear(1), nn.Sigmoid()
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

and the resume pipeline:

model = nn.Sequential(
    nn.Conv3d(1, 16, 3, 1, 1), 
    nn.BatchNorm3d(16),
    nn.Conv3d(16, 32, 3, 1, 1),
    nn.Flatten(start_dim=1),
    nn.LazyLinear(128), nn.GELU(),    
    nn.LazyLinear(1), nn.Sigmoid()
)
checkpoint = torch.load('blabla.pt')
model.load_state_dict(checkpoint['model'])
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
optimizer.load_state_dict(checkpoint['optimizer'])
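
For reference, here is a variant of the same loading code that keeps the checkpoint tensors on the CPU until the model is moved to the device (just a sketch, I haven't verified whether it actually changes the memory usage):

# same model definition as above
checkpoint = torch.load('blabla.pt', map_location='cpu')  # restore the saved tensors on the CPU
model.load_state_dict(checkpoint['model'])
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
optimizer.load_state_dict(checkpoint['optimizer'])  # state tensors are cast to the params' device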

I should also say that I train and resume on different cloud GPUs with the same memory size but not the same model (e.g. an A4000 and a P5000); could that really be the cause of a 2 GiB difference?
(Tell me if the question isn't understandable, bad English speaker here :slight_smile: )

And thanks in advance!

Are you using the same PyTorch environment in both setups (the working and failing one)?
Starting with the PyTorch release supporting CUDA 11.7, we enabled CUDA's lazy module loading, which reduces the context size by lazily loading only the needed kernels into the context.
2 GB sounds like too much, but it might be related.
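
If you want to compare the two machines, a rough sketch like this would show how much device memory is used right after the CUDA context is created (lazy module loading is controlled by the CUDA_MODULE_LOADING env variable, which has to be set before the context is created; mem_get_info also counts other allocations, so it's only a rough proxy):

import os
import torch

# check / force lazy module loading before any CUDA call initializes the context
print(torch.__version__, torch.version.cuda)
print("CUDA_MODULE_LOADING =", os.environ.get("CUDA_MODULE_LOADING", "<not set>"))
# os.environ["CUDA_MODULE_LOADING"] = "LAZY"  # uncomment to force it explicitly

torch.randn(1, device="cuda")  # creates the CUDA context
free, total = torch.cuda.mem_get_info()
print(f"used on the device (context + allocations): {(total - free) / 1024**2:.0f} MiB")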
Also, do you know if the 2 GB relates to any specific object, e.g. the model's parameter size, the optimizer state size, etc.?
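
E.g. something like this, run after your loading code, would show how much of the memory corresponds to actual tensors vs. what the caching allocator reserves (just a sketch using the model and optimizer from your post):

def tensor_bytes(tensors):
    # total size in bytes of an iterable of tensors
    return sum(t.numel() * t.element_size() for t in tensors)

param_mb = tensor_bytes(model.parameters()) / 1024**2
buffer_mb = tensor_bytes(model.buffers()) / 1024**2
# AdamW stores exp_avg / exp_avg_sq (and step) per parameter
opt_mb = tensor_bytes(
    t for state in optimizer.state.values()
    for t in state.values() if torch.is_tensor(t)
) / 1024**2

print(f"params: {param_mb:.1f} MiB, buffers: {buffer_mb:.1f} MiB, optimizer state: {opt_mb:.1f} MiB")
print(f"allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB, "
      f"reserved: {torch.cuda.memory_reserved() / 1024**2:.1f} MiB")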


Sorry, I thought I had deleted this post since it got hidden by the bot, and in the meantime I just reduced the VRAM utilization to dodge the error!

It was the same PyTorch environment and the same VRAM capacity but, as I mentioned, not the same GPU.
If I remember correctly, on the first one the VRAM utilization was at least 98% during training.
Thank you anyway :slight_smile:!