How to reset GPU memory usage when figuring out max batch size?

I have a function that searches for the maximum batch size a model can have on a given GPU. The logs of this function look like this:

Batch size 1 succeeded. Increasing to 2...
Batch size 2 succeeded. Increasing to 4...
Batch size 4 succeeded. Increasing to 8...
Batch size 8 succeeded. Increasing to 16...
Batch size 16 succeeded. Increasing to 32...
Batch size 32 succeeded. Increasing to 64...
Batch size 64 succeeded. Increasing to 128...
Batch size 128 succeeded. Increasing to 256...
Batch size 256 succeeded. Increasing to 512...
Batch size 512 failed. Binary searching...
# We deliberately start the binary search with bounds (256 - 50, 512) rather than
# (256, 512) so that the bug described later in this post can be detected
Batch size 359 failed. New bounds: [206, 359]
Batch size 282 failed. New bounds: [206, 282]
Batch size 244 failed. New bounds: [206, 244]
Batch size 225 failed. New bounds: [206, 225]
Batch size 215 failed. New bounds: [206, 215]
Batch size 210 failed. New bounds: [206, 210]
Batch size 208 failed. New bounds: [206, 208]
Batch size 207 failed. New bounds: [206, 207]

However, notice something odd about these logs: a batch size of 256 succeeds initially, yet during the binary search, much smaller batch sizes (all the way down to 207) fail.

I think this indicates some sort of bug on my end, where the GPU memory is not being reclaimed properly.

Right now, before each call to the function that does the forward/backward pass, I call torch.cuda.empty_cache(), but I suspect that is not enough.

What else should I be doing to reset the GPU memory state?

Could you share your training code snippet? torch.cuda.empty_cache() releases the unoccupied cached memory, but that's all it does; we need to see what is holding on to the occupied memory.
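
For example, you can compare the occupied memory against the cached memory with the built-in counters (a quick sketch, assuming a single GPU at device index 0):

import torch

# memory_allocated = memory occupied by live tensors (the occupied part)
# memory_reserved  = occupied memory plus the caching allocator's free blocks
print(f"allocated: {torch.cuda.memory_allocated(0) / 2**20:.0f} MiB")
print(f"reserved:  {torch.cuda.memory_reserved(0) / 2**20:.0f} MiB")

# Full per-device breakdown, including peak usage
print(torch.cuda.memory_summary(0))

If allocated stays high after a failed attempt, something on the Python side is still holding references to those tensors, and empty_cache() cannot free them.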

This is my code snippet:

import torch

# Settings, ModelWrapper, and loss are project-specific helpers imported elsewhere.

def binary_search_batch_size(cfg: Settings):
    def _is_cuda_oom(e: RuntimeError):
        return 'CUDA out of memory' in str(e)

    cfg = Settings.parse_obj(cfg)
    model = ModelWrapper(model=cfg.model.model(), embed_dim=cfg.model.output_dim).cuda()
    optimizer = cfg.optimizer.create_optimizer(model)

    def run():
        with torch.cuda.amp.autocast(enabled=True):
            # release cached memory and wait for pending kernels before each attempt
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            run_batch(model, optimizer, shape)

    batch_size = 1
    max_batch_size = 1
    while True:
        try:
            batch_size = max_batch_size
            shape = (batch_size, 3, 224, 224)
            run()
            max_batch_size *= 2
            print(f"Batch size {batch_size} succeeded. Increasing to {max_batch_size}...")
        except RuntimeError as e:
            if not _is_cuda_oom(e):
                raise
            print(f"Batch size {batch_size} failed. Binary searching...")
            # the -50 margin is a sanity check: sizes up to batch_size // 2 already
            # succeeded, so the search should never settle below that unless memory
            # is not being reclaimed between attempts
            low = batch_size // 2 - 50
            high = batch_size
            while low + 1 < high:
                batch_size = (low + high) // 2
                shape = (batch_size, 3, 224, 224)
                try:
                    run()
                    low = batch_size
                    print(f"Batch size {batch_size} succeeded. New bounds: [{low}, {high}]")
                except RuntimeError as e:
                    if not _is_cuda_oom(e):
                        raise
                    high = batch_size
                    print(f"Batch size {batch_size} failed. New bounds: [{low}, {high}]")
            max_batch_size = low
            break
    return max_batch_size

def run_batch(model, optimizer, shape):
    batch = dict(
        aug1=torch.randint(0, 256, shape, dtype=torch.uint8).cuda(),
        aug2=torch.randint(0, 256, shape, dtype=torch.uint8).cuda(),
    )
    
    image_logits1, image_logits2 = model(batch)
    loss_val = loss(image_logits1, image_logits2, 0)
    optimizer.zero_grad()
    loss_val.backward()
    optimizer.step()
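
For completeness, here is a minimal sketch of the more aggressive per-attempt reset I could try next, assuming the problem is Python-side references (stale gradients, the previous attempt's logits/loss, or the caught exception's traceback) keeping tensors alive; release_attempt_state is a hypothetical helper, not something already in my codebase:

import gc

import torch


def release_attempt_state(optimizer):
    # Drop gradient tensors entirely instead of zeroing them in place.
    optimizer.zero_grad(set_to_none=True)
    # Collect anything that is no longer reachable (including tensors kept
    # alive only by the traceback of a caught OOM exception, once it has
    # gone out of scope), then return cached blocks to the allocator.
    gc.collect()
    torch.cuda.empty_cache()
    # Reset the high-water mark so each attempt is measured independently.
    torch.cuda.reset_peak_memory_stats()

The idea would be to call this after every attempt, successful or failed, before moving on to the next batch size.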