How to avoid fragmentation during load_state_dict(load(...)) when used with an old model obj?

Currently, when a saved model is reloaded as a checkpoint, this code causes GPU RAM fragmentation:

torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH))
  1. torch.load loads the checkpoint onto CUDA, allocating new GPU RAM
  2. load_state_dict replaces the old parameters with the new ones, freeing the previously allocated GPU RAM

The following memory allocation sketch shows how a hole is created in these 2 steps, which is very likely to lead to memory fragmentation.

1. [allocated mem][prev model][new model][free]
2. [allocated mem][---hole---][new model][free]

If the checkpoints are reloaded multiple times and the model is ever-growing, there will be lots of free but non-reusable sections in memory, leading to a lot of wasted GPU RAM.
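To actually see that waste, one can compare what is held by live tensors vs. what the caching allocator has reserved around a reload (a small diagnostic sketch, assuming a recent PyTorch with torch.cuda.memory_allocated/memory_reserved, and re-using model/PATH from the snippet above):

import torch

def report(tag):
    # memory_allocated: bytes currently occupied by live tensors
    # memory_reserved: bytes held by the caching allocator, including freed-but-cached blocks
    alloc_mb = torch.cuda.memory_allocated() / 2**20
    reserved_mb = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc_mb:.1f}MB reserved={reserved_mb:.1f}MB")

report("before reload")
model.load_state_dict(torch.load(PATH))
report("after reload")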

Is there a way to add an alternative re-loading function which first frees the previous state dict and only then replaces it with the new one? Looking at the code, it appears it'd be very complicated to implement. Unless perhaps it'd be enough to move the old parameters off CUDA first, as in the sketch below.
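For example, the idea could be approximated in user code without touching load_state_dict at all (just a sketch, re-using model/PATH from above): move the old parameters off CUDA and release the cached blocks before the new weights are allocated:

model.to('cpu')                          # free the old parameters' GPU storage
torch.cuda.empty_cache()                 # return the cached blocks to the driver
model.load_state_dict(torch.load(PATH))  # the new weights land in the space just freed
model.to('cuda')                         # move the updated model back onto the GPU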

Another approach would be to split torch.load() into two parts, so that (1) the first part loads into general (CPU) RAM, (2) load_state_dict then assigns the weights, and (3) they are switched to the right CUDA device (I'm assuming CUDA in this example).
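Something close to that split is already possible with map_location (a sketch, assuming model already sits on the right CUDA device): load the checkpoint into CPU RAM only, then let load_state_dict copy each CPU tensor into the existing CUDA parameter in place, so a second CUDA copy of the model is never allocated:

state_dict = torch.load(PATH, map_location='cpu')  # (1) deserialize into CPU RAM only
model.load_state_dict(state_dict)                  # (2)+(3) copy into the existing CUDA parameters in place
del state_dict                                     # drop the temporary CPU copy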


So the understanding that came from working on GPU RAM fragmentation diagnostics is that this won't cause actual fragmentation: an average model is 100MB+, so a whole bunch of complete GPU memory pages will be freed and re-used later through remapping of free pages.

So the only issue here is if there is not enough memory left to load the new model without first unloading the old one, in which case the card will not be able to do much work anyway, other than perhaps very simple inference. So I suppose this is a very low priority for the devs to spend their time on.
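For completeness, if one did want to check that condition up front, something like this rough sketch would do (assuming a PyTorch recent enough to have torch.cuda.mem_get_info, and using the checkpoint file size as a crude proxy for the new model's footprint):

import os
import torch

free_bytes, _total_bytes = torch.cuda.mem_get_info()  # free/total bytes on the current device
checkpoint_bytes = os.path.getsize(PATH)               # crude proxy for the size of the new weights
if checkpoint_bytes > free_bytes:
    # not enough room for both copies at once: stage the load through CPU RAM instead
    model.load_state_dict(torch.load(PATH, map_location='cpu'))
else:
    model.load_state_dict(torch.load(PATH))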

All is good then.

Thanks again to @colesbury for his very insightful answer.