Hello,
I’m running Llama 3.1 8B with custom operators + all_reduce, and the memory footprint is larger than expected. The PyTorch memory snapshot suggests that the intermediate tensors that undergo all_reduce are requested to be freed, but the frees never actually complete, even after the entire forward pass.
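For context, this is roughly how I capture the snapshot (a minimal sketch; the model/input names are placeholders, not my actual code):

import torch

# Record allocator events (ALLOC / FREE_REQUESTED / FREE_COMPLETED, ...) before the run
torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run one forward pass of the compiled model here, e.g. out = compiled_model(input_ids) ...

# Dump the snapshot (viewable at https://pytorch.org/memory_viz) and stop recording
torch.cuda.memory._dump_snapshot("llama_fwd_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)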
I implemented my custom kernel and operator following this guide.
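For reference, the operator registration looks roughly like this (a simplified sketch of the torch.library.custom_op pattern, not my exact code; the real kernel launches CUTLASS and takes the extra tiling/config arguments visible in the generated code below):

import torch

@torch.library.custom_op("custom_gemm::gemm", mutates_args=())
def gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real op launches the custom CUDA kernel
    return a @ b

@gemm.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake/meta implementation so torch.compile can trace shapes without running the kernel
    return torch.empty(a.shape[0], b.shape[1], dtype=a.dtype, device=a.device)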
Here’s a snippet of the torch.compile-generated code (Inductor output) that creates the intermediate tensor and performs the all_reduce:
# Topologically Sorted Source Nodes: [out_2], Original ATen: [cutlass_gemm._gemm]
buf11 = torch.ops.custom_gemm.gemm.default(arg8_1, reinterpret_tensor(buf9, (2048, 32768), (1, 2048), 0), 128, 192, 1, 1, False, True, False, 0, 1)
del arg8_1
del buf9
buf12 = buf11
del buf11
# Topologically Sorted Source Nodes: [tensor], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf12, (128, 256, 4096), (1048576, 4096, 1), 0), 'sum', '0')
# Topologically Sorted Source Nodes: [wait_tensor], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf12, (128, 256, 4096), (1048576, 4096, 1), 0))
buf18 = empty_strided_cuda((128, 256, 4096), (1048576, 4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [x, h, mul_11, mean_1, add_5, rsqrt_1, mul_12, mul_13], Original ATen: [aten.embedding, aten.add, aten.mul, aten.mean, aten.rsqrt]
triton_red_fused_add_embedding_mean_mul_rsqrt_4.run(arg3_1, arg2_1, buf12, arg9_1, buf18, 32768, 4096, grid=grid(32768), stream=stream1)
buf12 also gets deleted soon after this, so I was expecting the tensor created by custom_gemm.gemm to be freed afterwards.
This code from CUDACachingAllocator.h suggests that some other stream might still be using the tensor (via record_stream), but I can’t find which one:
struct TraceEntry {
  enum Action {
    ...
    FREE_REQUESTED, // API call made to the caching allocator to free memory
    FREE_COMPLETED, // The allocator might have to delay a free because
                    // it is still in use on another stream via record_stream
                    // This event is generated when a free actually completes.
    ...
  };
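If I understand that comment correctly, it describes the usual record_stream pattern, something like this standalone sketch (not taken from my model):

import torch

s = torch.cuda.Stream()
x = torch.empty(1024 * 1024, device="cuda")
with torch.cuda.stream(s):
    y = x * 2         # x is consumed on a side stream
x.record_stream(s)    # allocator must wait for stream s before reusing x's block
del x                 # FREE_REQUESTED should show up in the trace here; FREE_COMPLETED
                      # is deferred until the allocator later observes that stream s
                      # has finished with the block (checked on subsequent allocations)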
Would it be possible to manually release the tensor? Also, the buffer does get overwritten (reused) when I don’t use the custom kernel and set the torch.compile config allow_buffer_reuse=True. So would there be a way to let the compiler overwrite buffers for custom kernels too?
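For reference, this is what I mean by the non-custom-kernel case (a rough sketch; model is a placeholder for my Llama wrapper, and allow_buffer_reuse is the only knob I’m aware of):

import torch
import torch._inductor.config as inductor_config

inductor_config.allow_buffer_reuse = True   # let Inductor reuse/overwrite intermediate buffers

compiled_model = torch.compile(model)
# out = compiled_model(input_ids)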