Currently, when a saved model is loaded back as a checkpoint, this code causes GPU RAM fragmentation:
torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH))
- torch.load loads the model onto cuda, allocating new GPU RAM (a quick way to observe this is sketched right after this list)
- load_state_dict replaces the old model with the new one, freeing the previously allocated GPU RAM
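This can be observed with something along these lines (only a rough sketch, assuming the setup above: a model already resident on a single CUDA device and a checkpoint saved with torch.save):

print(torch.cuda.memory_allocated())   # baseline: the old parameters only
state_dict = torch.load(PATH)          # new copies are allocated on cuda next to the old ones
print(torch.cuda.memory_allocated())   # roughly doubled for the parameter tensors
model.load_state_dict(state_dict)
del state_dict
print(torch.cuda.memory_allocated())   # back near the baseline, but with a hole left behind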
The following memory allocation sketch shows how a hole is created in these 2 steps, which is very likely to lead to memory fragmentation.
1. [allocated mem][prev model][new model][free]
2. [allocated mem][---hole---][new model][free]
If the checkpoints are loaded multiple times and the model is ever-growing, there will be lots of free but non-reusable sections in memory, leading to a lot of wasted GPU RAM.
Is there a way to add an alternative re-loading function which first frees up the previous dict and only then replaces it with the new one? Looking at the code, it appears it'd be very complicated to implement, unless perhaps it'd be enough to first remove it from CUDA.
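A workaround along those lines might look like the following (only a rough sketch with a hypothetical helper name, assuming a plain nn.Module and a single CUDA device):

def reload_checkpoint(model, path, device="cuda"):
    # hypothetical helper, not an existing API: free the old parameters' GPU memory
    # before the new state dict is materialized, so the same region can be reused
    model.to("cpu")                                     # move the old parameters off the GPU
    torch.cuda.empty_cache()                            # release the now-unused cached blocks
    state_dict = torch.load(path, map_location="cpu")   # deserialize into host RAM only
    model.load_state_dict(state_dict)                   # copy into the CPU-resident model
    model.to(device)                                    # move the refreshed model back to cuda
    return model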
Another approach would be to split torch.load() into two parts, so that (1) the first part loads into general RAM, (2) load_state_dict then assigns it, and (3) the model is switched to the right cuda device (I'm assuming cuda in this example).
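With the current API this three-step flow can roughly be approximated via map_location; a minimal sketch, assuming the same model and PATH as above:

state_dict = torch.load(PATH, map_location="cpu")   # (1) deserialize into general (CPU) RAM
model.load_state_dict(state_dict)                   # (2) assign the loaded tensors to the model
model.to("cuda")                                    # (3) move the model to the right cuda device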