I have a function that searches for the maximum batch size a model can fit on a given GPU. The function's logs look like this:
Batch size 1 succeeded. Increasing to 2...
Batch size 2 succeeded. Increasing to 4...
Batch size 4 succeeded. Increasing to 8...
Batch size 8 succeeded. Increasing to 16...
Batch size 16 succeeded. Increasing to 32...
Batch size 32 succeeded. Increasing to 64...
Batch size 64 succeeded. Increasing to 128...
Batch size 128 succeeded. Increasing to 256...
Batch size 256 succeeded. Increasing to 512...
Batch size 512 failed. Binary searching...
# Note: we start the binary search with bounds [256 - 50, 512] = [206, 512]
# so that the bug described later in this post shows up in the search range
Batch size 359 failed. New bounds: [206, 359]
Batch size 282 failed. New bounds: [206, 282]
Batch size 244 failed. New bounds: [206, 244]
Batch size 225 failed. New bounds: [206, 225]
Batch size 215 failed. New bounds: [206, 215]
Batch size 210 failed. New bounds: [206, 210]
Batch size 208 failed. New bounds: [206, 208]
Batch size 207 failed. New bounds: [206, 207]
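For reference, the search is a doubling phase followed by a binary search. Here is a minimal sketch of a loop that would produce logs like the ones above; the names and exact structure are illustrative, not my actual code, and try_batch stands in for the function that runs one forward/backward trial:

def find_max_batch_size(try_batch, start=1, lower_offset=50):
    # Phase 1: keep doubling the batch size until a trial fails.
    bs = start
    while try_batch(bs):
        print(f"Batch size {bs} succeeded. Increasing to {bs * 2}...")
        bs *= 2
    print(f"Batch size {bs} failed. Binary searching...")

    # Phase 2: binary search between the last success (minus an offset,
    # so the bug discussed below falls inside the search range) and the
    # first failure.
    lo, hi = bs // 2 - lower_offset, bs
    while hi - lo > 1:
        mid = (lo + hi) // 2
        ok = try_batch(mid)
        if ok:
            lo = mid
        else:
            hi = mid
        print(f"Batch size {mid} {'succeeded' if ok else 'failed'}. New bounds: [{lo}, {hi}]")
    return lo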
Notice something odd about these logs: batch size 256 succeeds during the initial doubling phase, yet during the binary search, batch sizes well below 256 (all the way down to 207) fail.
I think this indicates some sort of bug on my end, where the GPU memory is not being reclaimed properly.
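One way to confirm that would be to log the allocator's view of the device between trials, e.g. with a small helper like this (the helper is mine, just for illustration):

import torch

def log_gpu_memory(tag):
    # memory_allocated: bytes currently held by live tensors.
    # memory_reserved: bytes held by PyTorch's caching allocator,
    # including cached blocks that empty_cache() could release.
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"[{tag}] allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB")

If allocated memory keeps climbing after each failed trial, that would confirm something is still holding references to tensors from previous attempts.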
Right now, before each call to the function that does the forward/backward pass, I call torch.cuda.empty_cache(), but I think that might not be enough.
What else should I be doing to reset the GPU memory state?
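For concreteness, the cleanup around each trial currently looks roughly like this (simplified; the names are placeholders, and run_step stands for the function that does the forward/backward pass):

import torch

def make_try_batch(run_step):
    def try_batch(batch_size):
        # Current reset between trials: only empty_cache(), which releases
        # cached blocks that no live tensor still references. Anything still
        # referenced from a previous attempt stays allocated.
        torch.cuda.empty_cache()
        try:
            run_step(batch_size)
            return True
        except torch.cuda.OutOfMemoryError:
            return False
    return try_batch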