Why are GPU memory allocations associated with a CUDA stream?

In pytorch/c10/cuda/CUDACachingAllocator.cpp, I found this comment:

Allocations are associated with a stream. Once freed, blocks can be re-allocated on the same stream, but not on any other stream.

And the code does follow this rule.

I can’t figure out why this limitation exists. As far as I know, CUDA doesn’t forbid different CUDA streams from accessing memory allocated by cudaMalloc, and this limitation prevents free blocks allocated on one CUDA stream from being reused by other streams.

Example:

import torch

a = torch.rand((12, 12)).cuda()   # allocated by default cuda stream
print(torch.cuda.memory_cached())

torch.cuda.synchronize()
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    b = torch.rand((12, 12)).cuda()    # allocated by other cuda stream

torch.cuda.synchronize()
print(torch.cuda.memory_cached())

The output is 2097152 and 4194304 (2 MiB, then 4 MiB).

Compare with allocating both tensors on the default stream:

import torch

a = torch.rand((12, 12)).cuda()   # allocated by default cuda stream
print(torch.cuda.memory_cached())

b = torch.rand((12, 12)).cuda()   # allocated by default cuda stream
print(torch.cuda.memory_cached())

The output is 2097152 and 2097152 (2 MiB both times).

Environment: PyTorch 1.3, CUDA 10.1

My understanding is that this reduces the number of synchronizations needed.
cc @colesbury, who will have better insight.

As @albanD wrote, limiting CUDA allocations to a single stream reduces the number of CPU-GPU synchronizations necessary. CUDA kernels are asynchronous, so when an allocation is “freed”, the kernels that use it may not have finished (or may not even have started). Reusing the same allocation in a different stream could cause memory corruption, because work in that stream may start before previously launched work in the original stream finishes.
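To make the hazard concrete, here is a toy pure-Python simulation (no GPU or PyTorch needed). It assumes each stream is a FIFO queue of launched-but-not-yet-run operations and that the device may interleave different streams in any order; `Device` and `scenario` are hypothetical names for illustration, not CUDA or PyTorch APIs.

```python
from collections import deque

class Device:
    """Toy GPU: each stream is a FIFO queue of launched but not yet executed ops."""

    def __init__(self):
        self.memory = {}    # block id -> value currently stored there
        self.streams = {}   # stream name -> deque of pending ops

    def enqueue(self, stream, op):
        # Launching work is asynchronous: the op is only queued, not run.
        self.streams.setdefault(stream, deque()).append(op)

    def run(self, order):
        # Run the oldest pending op of each named stream, in the given order.
        # Ops on one stream stay in order; different streams may interleave.
        for stream in order:
            self.streams[stream].popleft()(self.memory)

def scenario(reuse_stream, order):
    dev = Device()
    seen = []
    # Stream "s0": a kernel writes block 0, a later kernel reads it.
    dev.enqueue("s0", lambda mem: mem.update({0: "a"}))
    dev.enqueue("s0", lambda mem: seen.append(mem[0]))
    # The CPU "frees" block 0 right away (nothing has run yet) and the
    # allocator hands it to new work on `reuse_stream`, which overwrites it.
    dev.enqueue(reuse_stream, lambda mem: mem.update({0: "b"}))
    dev.run(order)
    return seen[0]

print(scenario("s1", ["s0", "s1", "s0"]))  # prints b (reused on another stream)
print(scenario("s0", ["s0", "s0", "s0"]))  # prints a (reused on the same stream)
```

With the cross-stream interleaving, the reader on `s0` sees `"b"`, data written by the block's new owner. When the block is reused on `s0` itself, FIFO order puts the new write after the pending read, so the read is always correct.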

It’s safe to immediately reuse the allocation in the same stream because operations within a stream are ordered sequentially. This is the strategy the caching allocator uses.
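A minimal pure-Python model of stream-tagged free lists may help connect this to the numbers in the question. It is a sketch, not the real allocator: `ToyCachingAllocator` is hypothetical, it tracks only block sizes (no addresses, no merging), and the 2 MiB segment size and 512-byte rounding are assumed constants chosen to mirror the question's output.

```python
# Assumed constants for the small-allocation path (illustrative, not read from PyTorch).
SEGMENT_SIZE = 2 * 1024 * 1024   # bytes reserved per new segment (2 MiB)
ROUND = 512                      # request sizes are rounded up to a multiple of this

class ToyCachingAllocator:
    """Free lists keyed by stream: a cached block is only reused by its own stream."""

    def __init__(self):
        self.free_blocks = {}  # stream -> list of free block sizes (no addresses kept)
        self.reserved = 0      # total bytes obtained from the simulated cudaMalloc

    def malloc(self, size, stream):
        size = -(-size // ROUND) * ROUND            # round up to a multiple of ROUND
        pool = self.free_blocks.setdefault(stream, [])
        for i, free in enumerate(pool):
            if free >= size:
                pool[i] = free - size               # split a cached block of this stream
                return (stream, size)
        # Nothing suitable cached for this stream: reserve a new segment,
        # even if other streams have free space.
        segment = max(SEGMENT_SIZE, size)
        self.reserved += segment
        pool.append(segment - size)                 # the remainder stays cached
        return (stream, size)

    def free(self, block):
        stream, size = block
        # No synchronization: the block goes back to its own stream's pool only.
        self.free_blocks[stream].append(size)

# Mirrors the question: one 12x12 float32 tensor is 576 bytes.
alloc = ToyCachingAllocator()
alloc.malloc(12 * 12 * 4, stream=0)
print(alloc.reserved)   # 2097152
alloc.malloc(12 * 12 * 4, stream=1)
print(alloc.reserved)   # 4194304
```

Allocating the second tensor on `stream=0` instead leaves `reserved` at 2097152, because it is carved out of that stream's cached remainder, which matches the question's second experiment.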

The CUDA memory API handles this differently: the cudaFree call synchronizes all streams, so the CPU waits until every stream finishes all outstanding work before cudaFree returns. This guarantees that subsequent uses of the memory happen after the previous uses finish. However, it makes cudaFree a relatively expensive call. The primary goal of the caching allocator is to avoid this kind of synchronization.
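A toy pure-Python sketch of this trade-off, counting how often the CPU has to wait. All names (`Device`, `sync_free`, `cached_free`, `run`) are hypothetical; the model assumes a device-wide sync simply drains every stream's queue, and it says nothing about real timings.

```python
from collections import deque

class Device:
    """Toy GPU: per-stream FIFO queues of launched but not yet executed ops."""

    def __init__(self):
        self.queues = {}      # stream name -> deque of pending ops
        self.sync_count = 0   # times the CPU had to block waiting for the GPU

    def launch(self, stream, op):
        self.queues.setdefault(stream, deque()).append(op)  # asynchronous launch

    def synchronize(self):
        # Device-wide wait: drain every stream before returning to the CPU.
        self.sync_count += 1
        for queue in self.queues.values():
            while queue:
                queue.popleft()()

def sync_free(device, global_pool, block):
    """cudaFree-style: wait for all streams, then any stream may reuse the block."""
    device.synchronize()
    global_pool.append(block)

def cached_free(per_stream_pools, stream, block):
    """Caching-allocator style: no wait, but only `stream` may reuse the block."""
    per_stream_pools.setdefault(stream, []).append(block)

def run(n, caching):
    dev = Device()
    global_pool, per_stream_pools = [], {}
    for block in range(n):
        dev.launch("s0", lambda: None)  # stand-in for a kernel that uses `block`
        if caching:
            cached_free(per_stream_pools, "s0", block)
        else:
            sync_free(dev, global_pool, block)
    return dev.sync_count

print(run(100, caching=False))  # 100
print(run(100, caching=True))   # 0
```

The synchronizing free buys the freedom to hand memory to any stream at the cost of one CPU stall per free; the caching free never stalls but restricts reuse to the owning stream.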
