Workarounds for torch._dynamo eval_frame.py potentially leaking host memory (DRAM)?

I have some code for batched inference with a Whisper-like encoder-decoder model. I use queues to manage the batches (that code is omitted below); the main logic is as follows:

def decode_token(
    model: AmtEncoderDecoder,
    x: torch.Tensor,
    xa: torch.Tensor,
    x_input_pos: torch.Tensor,
    xa_input_pos: torch.Tensor,
):
    """Run one decoder forward pass and greedily choose the next token.

    ``gpu_manager`` rebinds this module-level name to a ``torch.compile``-d
    version (``mode="reduce-overhead"``, ``fullgraph=True``), so the body must
    stay free of graph breaks.

    Args:
        model: Encoder-decoder whose ``decoder.forward`` is invoked.
        x: Decoder token ids, forwarded as ``x``.
        xa: Encoder (audio) features, forwarded as ``xa``.
        x_input_pos: Forwarded as ``x_input_pos``.
        xa_input_pos: Forwarded as ``xa_input_pos``; callers pass an empty
            tensor on single-token steps (presumably because the
            cross-attention cache is already filled — confirm in the decoder).

    Returns:
        ``(logits, next_tok_ids)``: the decoder output at the last position
        along dim 1 (``[:, -1]``) and its argmax over the last dimension.
    """
    decoder_out = model.decoder.forward(
        x=x,
        xa=xa,
        x_input_pos=x_input_pos,
        xa_input_pos=xa_input_pos,
    )
    last_logits = decoder_out[:, -1]
    return last_logits, last_logits.argmax(dim=-1)


def _prefill_token(
    model: AmtEncoderDecoder,
    x: torch.Tensor,
    xa: torch.Tensor,
    x_input_pos: torch.Tensor,
    xa_input_pos: torch.Tensor,
):
    """Eager (never compiled) counterpart of ``decode_token`` for the prefill.

    The prefill input length is the shortest prefix length of the batch and
    therefore changes from batch to batch. Routing it through the
    ``mode="reduce-overhead"`` (CUDA graph) compiled ``decode_token`` makes
    every new length trigger a retrace plus a new recorded CUDA graph, which
    is kept around for reuse, so memory grows with the number of distinct
    prefix lengths seen. Keeping the prefill eager means the compiled function
    only ever sees the fixed single-token shapes.

    Returns:
        ``(logits, next_tok_ids)`` with the same meaning as ``decode_token``.
    """
    logits = model.decoder.forward(
        x=x,
        xa=xa,
        x_input_pos=x_input_pos,
        xa_input_pos=xa_input_pos,
    )[:, -1]
    return logits, torch.argmax(logits, dim=-1)


@optional_bf16_autocast
@torch.no_grad()
def process_segments(
    tasks: list,
    model: AmtEncoderDecoder,
    audio_transform: AudioTransform,
    tokenizer: AmtTokenizer,
    logger: logging.Logger,
):
    """Greedily transcribe one batch of audio segments.

    Args:
        tasks: Items unpacked as ``((audio_seg, prefix), _)``; the second
            element is ignored here.
        model: Encoder-decoder. Its decoder KV cache is expected to have been
            set up (``model.decoder.setup_cache``) for this batch size.
        audio_transform: Provides ``log_mel`` for the stacked audio segments.
        tokenizer: Used to truncate, encode and decode token sequences.
        logger: Receives a warning when decoding runs out of context.

    Returns:
        One ``tokenizer.decode`` result per task, in ``tasks`` order, covering
        tokens ``0 .. eos_idxs[i]`` inclusive. An empty ``tasks`` yields ``[]``.
    """
    if not tasks:
        # torch.stack([]) would raise; nothing to transcribe.
        return []

    audio_segs = torch.stack(
        [audio_seg for (audio_seg, _prefix), _ in tasks]
    ).cuda()
    log_mels = audio_transform.log_mel(audio_segs)
    audio_features = model.encoder(xa=log_mels)

    raw_prefixes = [prefix for (_audio_seg, prefix), _ in tasks]
    # NOTE(review): lengths come from the *untruncated* prefixes; confirm that
    # trunc_seq below never shortens a prefix relative to these lengths.
    prefix_lens = torch.tensor(
        [len(prefix) for prefix in raw_prefixes], dtype=torch.int
    )
    min_prefix_len = prefix_lens.min().item()
    prefixes = [
        tokenizer.trunc_seq(prefix, MAX_BLOCK_LEN) for prefix in raw_prefixes
    ]
    seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
    # Starts at MAX_BLOCK_LEN for every row; presumably update_seq_end_idxs_
    # lowers an entry once that row emits EOS — confirm.
    eos_idxs = torch.tensor([MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int)

    # Bound before the loop so the overflow check below never raises
    # UnboundLocalError when the decode range is empty.
    idx = min_prefix_len
    # The SDPA backend choice is loop-invariant, so enter it once.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    ):
        for idx in tqdm(
            range(min_prefix_len, MAX_BLOCK_LEN - 1),
            total=MAX_BLOCK_LEN - (min_prefix_len + 1),
            leave=False,
        ):
            if idx == min_prefix_len:
                # Prefill: feed the shared prefix seq[:, :idx]. Its length
                # varies per batch, so it runs eagerly (see _prefill_token).
                _, next_tok_ids = _prefill_token(
                    model,
                    x=seq[:, :idx],
                    xa=audio_features,
                    x_input_pos=torch.arange(0, idx, device=seq.device),
                    xa_input_pos=torch.arange(
                        0, audio_features.shape[1], device=seq.device
                    ),
                )
            else:
                # Single-token step: only the token at position idx - 1 is
                # fed. The shapes are fixed, which suits the CUDA-graph
                # compiled decode_token.
                _, next_tok_ids = decode_token(
                    model,
                    x=seq[:, idx - 1 : idx],
                    xa=audio_features,
                    x_input_pos=torch.tensor(
                        [idx - 1], device=seq.device, dtype=torch.int
                    ),
                    xa_input_pos=torch.tensor(
                        [], device=seq.device, dtype=torch.int
                    ),
                )

            update_seq_end_idxs_(
                next_tok_ids=next_tok_ids,
                seq=seq,
                eos_idxs=eos_idxs,
                prefix_lens=prefix_lens,
                idx=idx,
            )

            # Vectorised "every row has finished" check.
            if bool((eos_idxs <= idx).all()):
                break

    # If there is a context length overflow, we need some special logic to
    # make sure a sequence of the correct format is returned. Right now it
    # messes things up somehow.
    if not bool((eos_idxs <= idx).all()):
        logger.warning("Context length overflow when transcribing segment")

    return [
        tokenizer.decode(seq[row, : eos_idxs[row] + 1])
        for row in range(seq.shape[0])
    ]


def gpu_manager(
    gpu_batch_queue: Queue,
    result_queue: Queue,
    model: AmtEncoderDecoder,
    batch_size: int,
    gpu_id: int | None = None,
    trace_memory: bool = True,
    trace_every: int = 50,
):
    """Worker loop: pull batches, transcribe them on the GPU, push results.

    The loop exits when no batch arrives within 10 seconds. Any exception is
    logged and ends the worker.

    Args:
        gpu_batch_queue: Source of batches; each batch is passed to
            ``process_segments`` as ``tasks``.
        result_queue: Receives ``{"result": ..., "pid": ...}`` dicts. Entries
            whose pid is -1 (padding) are dropped.
        model: Encoder-decoder; its decoder cache is set up for
            ``batch_size`` and ``MAX_BLOCK_LEN``.
        batch_size: Batch size passed to ``model.decoder.setup_cache``.
        gpu_id: If given, written to ``CUDA_VISIBLE_DEVICES``.
        trace_memory: If true (the default), run tracemalloc and print the
            top host-memory growth sites periodically. tracemalloc itself
            costs memory and time, so disable it in production.
        trace_every: Batch interval between tracemalloc reports.
    """
    from queue import Empty as QueueEmpty

    global decode_token, recalculate_tok_ids

    logger = _setup_logger()
    if trace_memory:
        tracemalloc.start()

    logger.info("Started GPU manager")

    if gpu_id is not None:
        # Only effective if CUDA has not been initialised in this process yet.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # NOTE(review): recalculate_tok_ids is declared global but not rebound
    # here; confirm whether omitted code still needs it.
    model.decoder.setup_cache(batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN)
    model.cuda()
    model.eval()
    # "reduce-overhead" enables CUDA graphs. Every distinct input shape gets
    # its own recorded graph, so callers should only feed fixed-shape inputs
    # here and keep variable-length work (e.g. prefill) out of this function.
    decode_token = torch.compile(
        decode_token,
        mode="reduce-overhead",
        fullgraph=True,
    )

    audio_transform = AudioTransform().cuda()
    tokenizer = AmtTokenizer(return_tensors=True)

    # Baseline snapshot before the loop starts.
    prev_snapshot = tracemalloc.take_snapshot() if trace_memory else None

    cnt = 0
    try:
        while True:
            try:
                batch = gpu_batch_queue.get(timeout=10)
            except QueueEmpty:
                logger.info("GPU timed out waiting for batch")
                break

            try:
                results = process_segments(
                    tasks=batch,
                    model=model,
                    audio_transform=audio_transform,
                    tokenizer=tokenizer,
                    logger=logger,
                )
            except Exception:
                logger.error(
                    f"Failed to process batch: {traceback.format_exc()}"
                )
                raise

            # PROFILING CODE
            cnt += 1
            if trace_memory and cnt % trace_every == trace_every - 1:
                snapshot = tracemalloc.take_snapshot()
                top_stats = snapshot.compare_to(prev_snapshot, "lineno")
                print("--------")
                for stat in top_stats[:4]:
                    print(stat)
                prev_snapshot = snapshot

            # pid = -1 when its a pad sequence
            for result, (_, pid) in zip(results, batch):
                if pid != -1:
                    result_queue.put({"result": result, "pid": pid})

    except Exception as e:
        logger.error(f"GPU manager failed with exception: {e}")
    finally:
        logger.info("GPU manager terminated")

Using the included tracemalloc profiling code (host/DRAM memory), I see the following growth every 50 batches:

<frozen importlib._bootstrap_external>:729: size=11.4 MiB (+11.2 MiB), count=89245 (+88290), average=134 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/linecache.py:137: size=4236 KiB (+4236 KiB), count=41961 (+41961), average=103 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/codecache.py:1884: size=4228 KiB (+4228 KiB), count=1991 (+1991), average=2175 B
--------
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:489: size=98.5 MiB (+48.4 MiB), count=3689044 (+1813686), average=28 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/tracemalloc.py:129: size=1220 KiB (+1220 KiB), count=17350 (+17350), average=72 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/tracemalloc.py:498: size=826 KiB (+826 KiB), count=17623 (+17623), average=48 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/tracemalloc.py:193: size=826 KiB (+826 KiB), count=17623 (+17623), average=48 B
--------
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:489: size=149 MiB (+50.1 MiB), count=5566156 (+1877112), average=28 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/tracemalloc.py:129: size=1440 B (-1219 KiB), count=20 (-17330), average=72 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/tracemalloc.py:125: size=1236 KiB (+1218 KiB), count=17581 (+17328), average=72 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:1465: size=486 KiB (+314 KiB), count=5656 (+3656), average=88 B
--------
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:489: size=203 MiB (+54.1 MiB), count=7590301 (+2024145), average=28 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:1465: size=616 KiB (+130 KiB), count=7166 (+1510), average=88 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:1468: size=560 KiB (+118 KiB), count=7166 (+1509), average=80 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:809: size=388 KiB (+81.2 KiB), count=7089 (+1485), average=56 B
--------
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:489: size=264 MiB (+61.2 MiB), count=9882844 (+2292543), average=28 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:1465: size=832 KiB (+216 KiB), count=9676 (+2510), average=88 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:1468: size=756 KiB (+196 KiB), count=9677 (+2511), average=80 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:809: size=523 KiB (+135 KiB), count=9564 (+2475), average=56 B
--------
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:489: size=322 MiB (+58.5 MiB), count=12072808 (+2189964), average=28 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:1465: size=1074 KiB (+243 KiB), count=12502 (+2826), average=88 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:1468: size=977 KiB (+221 KiB), count=12502 (+2825), average=80 B
/home/loubb/miniconda3/envs/amt/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py:809: size=676 KiB (+153 KiB), count=12369 (+2805), average=56 B

This keeps happening (more than 1 MiB per batch) until the process runs out of memory. I've searched for memory issues related to eval_frame.py and found nothing. I've been debugging this for a few days, including clearing the various caches etc., but nothing has worked. My current workaround is to kill the process when its memory consumption reaches a certain threshold and restart it manually.

Note that this appears to happen only when using torch.compile (I'm fairly sure of this). Does anyone know a workaround for this kind of problem?

1 Like