Increased memory footprint with custom kernel and all reduce

Hello,
I’m running Llama 3.1 8B with custom operators plus all-reduce, and the memory footprint is larger than expected. PyTorch’s memory trace seems to suggest that the intermediate tensors that undergo all-reduce are requested to be freed but are still not actually freed after the entire forward pass.

I implemented my custom kernel and operator following this guide.

Here’s a snippet of the Inductor output code that creates the intermediate tensor and performs the all-reduce.

        # Topologically Sorted Source Nodes: [out_2], Original ATen: [cutlass_gemm._gemm]
        buf11 = torch.ops.custom_gemm.gemm.default(arg8_1, reinterpret_tensor(buf9, (2048, 32768), (1, 2048), 0), 128, 192, 1, 1, False, True, False, 0, 1)
        del arg8_1
        del buf9
        buf12 = buf11
        del buf11
        # Topologically Sorted Source Nodes: [tensor], Original ATen: [_c10d_functional.all_reduce]
        torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf12, (128, 256, 4096), (1048576, 4096, 1), 0), 'sum', '0')
        # Topologically Sorted Source Nodes: [wait_tensor], Original ATen: [_c10d_functional.wait_tensor]
        torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf12, (128, 256, 4096), (1048576, 4096, 1), 0))
        buf18 = empty_strided_cuda((128, 256, 4096), (1048576, 4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [x, h, mul_11, mean_1, add_5, rsqrt_1, mul_12, mul_13], Original ATen: [aten.embedding, aten.add, aten.mul, aten.mean, aten.rsqrt]
        triton_red_fused_add_embedding_mean_mul_rsqrt_4.run(arg3_1, arg2_1, buf12, arg9_1, buf18, 32768, 4096, grid=grid(32768), stream=stream1)

buf12 also gets deleted soon after, so I was expecting the tensor created by custom_gemm.gemm to be freed at that point.

This comment in CUDACachingAllocator.h suggests that another stream might still be using the tensor, but I can’t find any such stream.

struct TraceEntry {
  enum Action {
    ...
    FREE_REQUESTED, // API call made to the caching allocator to free memory
    FREE_COMPLETED, // The allocator might have to delay a free because
                    // it is still in use on another stream via record_stream
                    // This event is generated when a free actually completes.
    ...
  };
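To make the FREE_REQUESTED / FREE_COMPLETED distinction concrete, here is a toy Python model of a deferred free. It is only a sketch, not the real allocator: `ToyCachingAllocator`, `Block`, `record_stream`, and `stream_finished` are hypothetical names, and streams are plain strings. As far as I understand, the real CUDA caching allocator does this with CUDA events that it checks lazily rather than with explicit "stream finished" callbacks, but the bookkeeping idea is the same: a free is only completed once no recorded stream still uses the block.

```python
from dataclasses import dataclass, field
from typing import List, Set, Tuple


@dataclass
class Block:
    """A toy memory block. `pending_streams` mimics streams recorded via record_stream."""
    name: str
    pending_streams: Set[str] = field(default_factory=set)
    free_requested: bool = False
    freed: bool = False


class ToyCachingAllocator:
    """Toy model of deferred frees; NOT the real CUDACachingAllocator."""

    def __init__(self) -> None:
        self.trace: List[Tuple[str, str]] = []  # (action, block name)

    def record_stream(self, block: Block, stream: str) -> None:
        # Mark the block as still in use by work queued on `stream`.
        block.pending_streams.add(stream)

    def free(self, block: Block) -> None:
        # The caller is done with the block; the actual free may still be deferred.
        block.free_requested = True
        self.trace.append(("FREE_REQUESTED", block.name))
        self._maybe_complete(block)

    def stream_finished(self, block: Block, stream: str) -> None:
        # Work on `stream` that used the block has completed.
        block.pending_streams.discard(stream)
        self._maybe_complete(block)

    def _maybe_complete(self, block: Block) -> None:
        # Return memory only once a free was requested AND no stream still uses it.
        if block.free_requested and not block.pending_streams and not block.freed:
            block.freed = True
            self.trace.append(("FREE_COMPLETED", block.name))


alloc = ToyCachingAllocator()
buf = Block("buf_a")
alloc.record_stream(buf, "side_stream")
alloc.free(buf)                              # only FREE_REQUESTED so far
alloc.stream_finished(buf, "side_stream")    # now FREE_COMPLETED is emitted
print(alloc.trace)
# → [('FREE_REQUESTED', 'buf_a'), ('FREE_COMPLETED', 'buf_a')]
```

In this model, a FREE_REQUESTED with no matching FREE_COMPLETED means some stream never released the block.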

Is it possible to release a tensor manually? Alternatively, it looks like the buffer can be reused (overwritten) when I don’t use the custom kernel and set the torch.compile config allow_buffer_reuse=True. Is there a way to let the compiler reuse buffers for custom kernels too?
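For intuition about what buffer reuse buys, here is a toy liveness-based planner in Python. It is a simplified sketch under stated assumptions, not Inductor’s actual algorithm: `plan_buffers` and the `(name, size, inputs)` op format are hypothetical, buffers only match on exact size, and ops never write in place. Once a buffer’s last reader has run, its storage goes into a pool, and a later buffer of the same size can take it over instead of triggering a fresh allocation.

```python
from typing import Dict, List, Tuple

# A toy op: (output buffer name, size in bytes, names of buffers it reads).
Op = Tuple[str, int, List[str]]


def plan_buffers(ops: List[Op], allow_reuse: bool) -> Tuple[Dict[str, str], int]:
    """Toy liveness-based planner (not Inductor's real one).

    Returns a mapping buffer -> storage id and the number of fresh allocations.
    """
    # Index of the last op that reads each buffer.
    last_use: Dict[str, int] = {}
    for i, (_, _, inputs) in enumerate(ops):
        for name in inputs:
            last_use[name] = i

    free_pool: Dict[int, List[str]] = {}  # size -> storages of dead buffers
    storage_of: Dict[str, str] = {}
    size_of: Dict[str, int] = {}
    allocations = 0
    for i, (out, size, inputs) in enumerate(ops):
        pool = free_pool.get(size, [])
        if allow_reuse and pool:
            storage_of[out] = pool.pop()  # hand a dead buffer's storage to `out`
        else:
            allocations += 1
            storage_of[out] = f"storage{allocations}"
        size_of[out] = size
        # Buffers read for the last time by this op are dead once it has run.
        for name in set(inputs):
            if last_use[name] == i and name in size_of:
                free_pool.setdefault(size_of[name], []).append(storage_of[name])
    return storage_of, allocations


chain = [("buf0", 1024, []), ("buf1", 1024, ["buf0"]),
         ("buf2", 1024, ["buf1"]), ("buf3", 1024, ["buf2"])]
print(plan_buffers(chain, allow_reuse=False)[1])  # → 4
print(plan_buffers(chain, allow_reuse=True)[1])   # → 2

# If buf0 is read again at the end, it stays live and cannot be recycled early.
chain_long = chain[:3] + [("buf3", 1024, ["buf2", "buf0"])]
print(plan_buffers(chain_long, allow_reuse=True)[1])  # → 3
```

The last case is the useful intuition here: anything that keeps a buffer referenced later in the program delays the point where its storage can be recycled.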

It’s a bit hard to tell from that standalone snippet of Inductor output, but it sounds like you think torch.compile (Inductor) is keeping the output of your custom gemm alive longer than it needs to (given that I don’t see del buf12 anywhere in that region).

A few things that would help more:

(1) If you can run your entire repro with TORCH_TRACE=some_tmp_directory python your_repro.py and then zip up and include the output, we’ll have more helpful info: we can look at both the graph that torch.compile captured and the entire Inductor output code, to see why Inductor chose not to free the buffer sooner. If you’re interested, there’s more info about TORCH_TRACE in Ed’s podcast: https://pytorch-dev-podcast.simplecast.com/episodes/torch-trace-and-tlparse

(2) Even better, if you’re able to create a self-contained repro script that someone can run, please share that.

(3) If you are seeing a higher memory footprint than in eager mode, this is definitely a bug, and you’ll get more help by filing a GitHub issue.

Thanks for the suggestions! I’ve resolved the problem.

It was because I was using the legacy all-reduce, which Inductor converts to functional_collectives.all_reduce_inplace. This in-place all-reduce was causing the memory leak.
