How to maximize CPU <==> GPU memory transfer speeds?

Hi, I am looking into different ways to optimize the running speed of my code. One of them is the speed of memory transfers between CPU and GPU, and the transfer speeds I have measured do not seem to match the hardware's theoretical bandwidth. I have written the following script:

(note: I decided to re-use the same pinned memory buffer, in order to avoid the overhead from re-allocating it over and over again)

import argparse
import time

import torch
from tqdm import trange


def stress_vram_transfer(
        batch_size=10,
        warmup=5,
        repeats=100,
        frame_shape=(3, 3840, 2160),
        use_pinned_memory=True,
):
    tensor = torch.randn((batch_size, *frame_shape))
    if use_pinned_memory:
        tensor = tensor.pin_memory()
    in_loop_tensor = tensor

    for device_id in range(torch.cuda.device_count()):
        print(f"Starting test for device {device_id}: {torch.cuda.get_device_properties(device_id)}")
        for _ in trange(warmup, desc="warmup"):
            in_loop_tensor = in_loop_tensor.cuda(device_id)
            if use_pinned_memory:
                # .cpu() returns a new pageable tensor, which is then copied into the pinned buffer
                tensor[:] = in_loop_tensor.cpu()
                in_loop_tensor = tensor
            else:
                in_loop_tensor = in_loop_tensor.cpu()
        start = time.perf_counter()
        for _ in trange(repeats, desc="test"):
            in_loop_tensor = in_loop_tensor.cuda(device_id)
            if use_pinned_memory:
                # .cpu() returns a new pageable tensor, which is then copied into the pinned buffer
                tensor[:] = in_loop_tensor.cpu()
                in_loop_tensor = tensor
            else:
                in_loop_tensor = in_loop_tensor.cpu()
        end = time.perf_counter()
        print(f"Total time taken: {end-start:.2f}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--frame_shape", type=int, nargs=3, default=(3, 3840, 2160))
    # Pinned memory is used by default; pass --no_pin to use pageable memory instead
    parser.add_argument('--no_pin', dest='use_pinned_memory', action='store_false')
    args = parser.parse_args()

    args = dict(vars(args))
    print(args)
    stress_vram_transfer(**args)

Using this, and running on an RTX3090, this is what I get:

$ CUDA_VISIBLE_DEVICES=1 python memory_transfer_test.py --batch_size 50 --no_pin
{'batch_size': 50, 'warmup': 5, 'repeats': 100, 'frame_shape': (3, 3840, 2160), 'use_pinned_memory': False}
Starting test for device 0: _CudaDeviceProperties(name='NVIDIA GeForce RTX 3090', major=8, minor=6, total_memory=24268MB, multi_processor_count=82)
warmup: 100%|_____________________________________________________________________________________________________________________________________________| 5/5 [00:13<00:00,  2.71s/it]
test: 100%|___________________________________________________________________________________________________________________________________________| 100/100 [03:42<00:00,  2.22s/it]
Total time taken: 222.21
$ CUDA_VISIBLE_DEVICES=1 python memory_transfer_test.py --batch_size 50
{'batch_size': 50, 'warmup': 5, 'repeats': 100, 'frame_shape': (3, 3840, 2160), 'use_pinned_memory': True}
Starting test for device 0: _CudaDeviceProperties(name='NVIDIA GeForce RTX 3090', major=8, minor=6, total_memory=24268MB, multi_processor_count=82)
warmup: 100%|_____________________________________________________________________________________________________________________________________________| 5/5 [00:13<00:00,  2.64s/it]
test: 100%|___________________________________________________________________________________________________________________________________________| 100/100 [04:23<00:00,  2.63s/it]
Total time taken: 263.24

These results are surprising on several fronts:

  1. The memory transfer speed is MUCH slower than what the hardware promises: nvidia-smi reports 6770MB of VRAM used by my process, and each step takes about 2.22s to transfer it both ways, which works out to 2*6770MB/2.22s ≈ 6100MB/s. An RTX 3090 is supposed to have 936.2 GB/s of memory bandwidth: even if I divide that by 4 to account for my card only having access to 4 PCIe lanes instead of the maximum of 16, I am still faced with at least a 10x discrepancy.
  2. The use_pinned_memory version actually performs slower than the no_pin version, even though the PyTorch documentation recommends pin_memory() for faster host-to-GPU copies…
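For comparison, the bytes actually copied per step can be computed from the tensor shape instead of the nvidia-smi figure, which also includes the CUDA context and memory cached by PyTorch's allocator. Note also that 936.2 GB/s is the bandwidth of the card's on-board GDDR6X memory; host-to-device copies travel over PCIe instead. A minimal sketch, assuming float32 elements (the default dtype of torch.randn) and a PCIe 4.0 link (halve the PCIe figures for a PCIe 3.0 slot); real links typically deliver somewhat less than these theoretical numbers:

```python
import math

# Bytes actually moved per direction, assuming float32 (4 bytes per element)
batch_size, frame_shape = 50, (3, 3840, 2160)
bytes_per_element = 4
payload_bytes = math.prod((batch_size, *frame_shape)) * bytes_per_element

# Effective throughput of the no_pin run: one H2D + one D2H copy per 2.22 s step
effective_gb_s = 2 * payload_bytes / 2.22 / 1e9

# Theoretical PCIe 4.0 bandwidth per lane and direction: 16 GT/s, 128b/130b encoding
pcie4_lane_gb_s = 16e9 * 128 / 130 / 8 / 1e9

print(f"payload per direction: {payload_bytes / 1e9:.2f} GB")
print(f"effective throughput:  {effective_gb_s:.2f} GB/s")
print(f"PCIe 4.0 x4 / x16:     {4 * pcie4_lane_gb_s:.2f} / {16 * pcie4_lane_gb_s:.2f} GB/s")
```

Measured against roughly 7.9 GB/s for an x4 link rather than hundreds of GB/s, the observed throughput is much closer to the theoretical limit.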

Given these observations, I have the following questions:

  1. Am I wrong to expect the memory bandwidth advertised for my GPU to match up with the tensor transfer speeds in PyTorch?
  2. Am I using pin_memory() correctly in this test? From my understanding, a major advantage of this function is asynchronous GPU copies, but that does not seem like something that can be exploited in this scenario.
  3. If using pin_memory() is supposed to help with "Host to GPU copies", does it also apply to GPU to Host copies? Or is there a separate trick to use?
  4. If the information from point 3 on this page is to be believed, PyTorch will actually copy non-pinned tensors to pinned memory before copying them to the GPU. Is there any reason to expect that to be slower than manually copying the tensor to pinned memory and then asking PyTorch to copy it to the GPU?

I definitely want to pay attention to any responses on this. I've noticed that as GPUs, software, and the algorithms used in that software get faster and faster, simple things like saving an inference-generated image in Stable Diffusion become a significant part of the overall time. It is no different here: if a result is computed very quickly, the transfer to and from the GPU becomes significant. I have a 4090 and a fast i9-13900K CPU, so I want to see if this memory pinning helps.

Prior to starting in this new field of AI and GPUs, I was a perf expert in conventional systems. Something similar to memory pinning was done for things like RDMA. The trick was having enough data to transfer that the overhead of pinning/unpinning memory didn't wipe out the gain. I'm just beginning to learn this stuff, so I don't know whether NVIDIA has any overhead with pinning, but pinning memory pages in the CPU's memory certainly does add some. The trick I learned was to reserve a pinned memory region once and keep reusing it instead of pinning for each individual transfer. I wonder if a similar concept exists here. Hmm, silly me… I just noticed you actually are doing that.

I'll experiment with this myself. The torch docs I just looked at don't say whether memory pinning applies to CPU or GPU memory. Since I'm on Linux, I can simply pin the entire memory space of the app process to guarantee it is pinned. Yes, I know that is dangerous if the process is too big, but it is an experiment to see if there is any perf impact, and I know what I'm doing. Another technique would be to strace the process to see if there are many pin/unpin calls for each transfer. But that'll be for tomorrow.


This looks like it’ll be a good read for me to better understand the subject. Too tired right now…

  1. I would recommend reading through the linked blog post about memory transfers and running a few benchmarks if you are interested in profiling your system (without PyTorch, to reduce the complexity of the entire stack).

  2. Using pinned memory avoids a staging copy and should perform better; it also allows you to use non-blocking data transfers, as explained in the blog post.

  3. Yes, using pin_memory=True allows you to use non-blocking copies, so you can overlap the data transfer with another operation. However, if the very next operation depends on the transferred tensor, there won't be any overlapping operation, so I'm unsure what your expectations in your test would be.

  4. Yes, device-to-host copies can also use pinned memory, but you would need to synchronize the device before accessing the data.

  5. No, unless the process pins so much memory that overall system performance suffers and the system starts to swap.

Here is a small example showing a non-blocking D2H copy and the common mistake of reading corrupted data because the code is not properly synchronized:

import torch
from torch.testing._internal.common_utils import get_cycles_per_ms

stream = torch.cuda.current_stream()
dst = "cpu"
non_blocking = True
a = torch.randn(1000000, device="cuda")
ref = a.cpu()

torch.cuda.synchronize()
# Pushes a 2-second spin to the stream so that, if the copy is non-blocking,
# the stream will almost surely still be active when we query() it.
torch.cuda._sleep(int(2000 * get_cycles_per_ms()))
b = a.to(device=dst, non_blocking=non_blocking)
print(stream.query(), not non_blocking)
# False False
print((b - ref).abs().max())
# tensor(7.3916)
stream.synchronize()
print((b - ref).abs().max())
# tensor(0.)
print(b.is_pinned())
# True
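Point 3 above (overlapping a non-blocking copy with other work) can be sketched roughly as follows. This is a minimal sketch assuming a single GPU; `overlapped_double` is a hypothetical helper, and doubling each chunk stands in for real work. All host-to-device copies are enqueued on a side stream, and the default stream waits on a per-chunk event, so chunk i+1 can still be copying while chunk i is being processed. Without CUDA it falls back to plain CPU math so it still runs.

```python
import torch


def overlapped_double(chunks):
    """Copy CPU chunks to the GPU on a side stream while the default stream
    processes chunks whose copies have already finished (hypothetical helper)."""
    if not torch.cuda.is_available():
        # CPU-only fallback so the sketch still runs without a GPU
        return [c * 2 for c in chunks]

    device = torch.device("cuda")
    compute_stream = torch.cuda.current_stream(device)
    copy_stream = torch.cuda.Stream(device)

    # Pinned host memory is what makes non_blocking copies truly asynchronous
    pinned = [c.pin_memory() for c in chunks]
    gpu_chunks, copied = [], []
    with torch.cuda.stream(copy_stream):
        for c in pinned:
            gpu_chunks.append(c.to(device, non_blocking=True))
            ev = torch.cuda.Event()
            ev.record(copy_stream)  # marks the end of this chunk's copy
            copied.append(ev)

    results = []
    for g, ev in zip(gpu_chunks, copied):
        compute_stream.wait_event(ev)  # wait only for this chunk's copy
        g.record_stream(compute_stream)  # g is also used on the compute stream
        results.append((g * 2).cpu())  # blocking D2H, so results are safe to read
    return results
```

The `record_stream` call tells PyTorch's caching allocator that each GPU tensor, allocated on the side stream, is also used on the compute stream, so its memory is not reused too early.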

Thank you very much for this thorough reply. This is very enlightening.

1: The benchmark in the blog post @aifartist linked to does seem very relevant for testing the system's capabilities without PyTorch in the mix. I'd like to run it, but I am not familiar with C++, so I will need some time to figure out how to set up the right environment and compile it. Any pointers would be appreciated.
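Until the C++ benchmark is compiled, a rough PyTorch-only approximation can be sketched with CUDA events, timing each direction separately for pageable and pinned host buffers. This is a minimal sketch, not the blog's benchmark: `gb_per_s` and `measure_copy_bandwidth` are hypothetical helper names, and the 256 MiB buffer size and 10 repeats are arbitrary assumptions. Because the copies go through PyTorch, the numbers include some framework overhead.

```python
import torch


def gb_per_s(nbytes, ms):
    """Convert a byte count and elapsed milliseconds to GB/s (1 GB = 1e9 bytes)."""
    return nbytes / (ms / 1e3) / 1e9


def measure_copy_bandwidth(nbytes=256 * 1024**2, pinned=True, repeats=10):
    """Time H2D and D2H copies of one buffer with CUDA events; returns GB/s per direction."""
    host = torch.empty(nbytes, dtype=torch.uint8, pin_memory=pinned)
    dev = torch.empty(nbytes, dtype=torch.uint8, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    results = {}
    for name, dst, src in (("H2D", dev, host), ("D2H", host, dev)):
        dst.copy_(src, non_blocking=pinned)  # warm-up copy
        torch.cuda.synchronize()
        start.record()
        for _ in range(repeats):
            dst.copy_(src, non_blocking=pinned)
        end.record()
        end.synchronize()  # wait until the last copy has finished
        results[name] = gb_per_s(nbytes * repeats, start.elapsed_time(end))
    return results


if __name__ == "__main__":
    if torch.cuda.is_available():
        for pinned in (False, True):
            label = "pinned" if pinned else "pageable"
            print(label, measure_copy_bandwidth(pinned=pinned))
```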

2 3 4 5: The .to(device="cpu", non_blocking=True) method is indeed useful even in my test, which does not leverage the actual non-blocking nature of the calls (I assume that's thanks to avoiding a pinned ==> paged ==> pinned round trip). Is this behavior of returning pinned CPU tensors documented anywhere, though?

6: I assume allocating pinned memory takes a bit of time. Would I be wrong to hope for some increase in speed if I manage to recycle the same pinned memory buffer repeatedly?

7: If yes, is there a way to ensure that in PyTorch currently? Tensor.to() seems to be lacking the out= argument found in many other operations such as torch.add()

For your reference, here is the version of my code after updating it with non_blocking:

import argparse
import time

import torch
from tqdm import trange


def stress_vram_transfer(
        batch_size=10,
        warmup=5,
        repeats=100,
        frame_shape=(3, 3840, 2160),
        use_pinned_memory=True,
):
    tensor = torch.randn((batch_size, *frame_shape))
    if use_pinned_memory:
        tensor = tensor.pin_memory()

    for device_id in range(torch.cuda.device_count()):
        print(f"Starting test for device {device_id}: {torch.cuda.get_device_properties(device_id)}")
        for _ in trange(warmup, desc="warmup"):
            tensor = tensor.to(device=device_id, non_blocking=use_pinned_memory)
            tensor = tensor.to(device="cpu", non_blocking=use_pinned_memory)
            if use_pinned_memory:
                torch.cuda.current_stream(device=device_id).synchronize()
        start = time.perf_counter()
        for _ in trange(repeats, desc="test"):
            tensor = tensor.to(device=device_id, non_blocking=use_pinned_memory)
            tensor = tensor.to(device="cpu", non_blocking=use_pinned_memory)
            if use_pinned_memory:
                torch.cuda.current_stream(device=device_id).synchronize()
        end = time.perf_counter()
        print(f"Total time taken: {end-start:.2f}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--frame_shape", type=int, nargs=3, default=(3, 3840, 2160))
    # Pinned memory is used by default; pass --no_pin to use pageable memory instead
    parser.add_argument('--no_pin', dest='use_pinned_memory', action='store_false')
    args = parser.parse_args()

    args = dict(vars(args))
    print(args)
    stress_vram_transfer(**args)

The time to run the test went from 263.24s down to 157.48s, beating the no_pin version’s 222.21s, and bringing the memory transfer speed to roughly 8600MB/s. That is better, but still far behind the hundreds of GB/s I was hoping for.

@ptrblck

If I "strace -f" my GPU app, I find not a single mlock* call, nor any mmap call with MAP_LOCKED.

I presume there isn't a major bug in PyTorch where it fails to pin the memory, such that crashes might occur under stress if the system starts swapping pages. I only know of the two ways mentioned above to pin memory pages.

PyTorch uses cudaHostAlloc, but you can also call cudaHostRegister via torch.cuda.cudart() in case you want to manage the host memory with MAP_LOCKED manually.
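A minimal sketch of the cudaHostRegister route, assuming a CUDA build of PyTorch whose torch.cuda.cudart() module exposes cudaHostRegister and cudaHostUnregister; `register_pinned` and `unregister_pinned` are hypothetical helper names. Instead of allocating a new pinned buffer, it page-locks the memory of an existing contiguous CPU tensor in place, and that memory must be unregistered before the tensor is freed.

```python
import torch


def register_pinned(tensor):
    """Page-lock an existing contiguous CPU tensor's memory in place via
    cudaHostRegister; returns True on success (hypothetical helper)."""
    if not torch.cuda.is_available():
        return False  # no CUDA runtime to register with
    assert tensor.device.type == "cpu" and tensor.is_contiguous()
    nbytes = tensor.numel() * tensor.element_size()
    # flags=0 corresponds to cudaHostRegisterDefault
    err = torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), nbytes, 0)
    return int(err) == 0  # 0 is cudaSuccess


def unregister_pinned(tensor):
    """Undo register_pinned; call this before the tensor's memory is freed."""
    err = torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr())
    return int(err) == 0


if __name__ == "__main__":
    buf = torch.zeros(64 * 1024**2, dtype=torch.uint8)  # 64 MiB host buffer
    if register_pinned(buf):
        dev = buf.to("cuda", non_blocking=True)
        torch.cuda.synchronize()  # finish the copy before unregistering
        print("copy ok:", bool(dev.eq(0).all()))
        unregister_pinned(buf)
```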


Regarding this, I found out that Tensor.copy_ enables exactly that: if you pre-allocate the pinned tensor, you can use pinned_tensor.copy_(cuda_tensor, non_blocking=True) to copy the CUDA tensor directly into it, and vice versa to copy pinned tensors into pre-allocated CUDA tensors.

However, when testing out this optimization, I saw no measurable performance increase. I suspect this might be because PyTorch already uses a caching memory allocator that efficiently takes care of this in the background, at least for this simple test case.
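A minimal sketch of that pattern, assuming a single GPU; `round_trip_reuse` is a hypothetical helper, and on a machine without CUDA it falls back to plain CPU copies so it still runs. Both buffers are allocated once and reused with copy_, and the stream is synchronized before the CPU reads the pinned buffer, as in the D2H example earlier in the thread.

```python
import torch


def round_trip_reuse(host_src, repeats=3):
    """Round-trip host_src through the GPU `repeats` times, reusing one
    pre-allocated host buffer and one device buffer (hypothetical helper)."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Allocate both buffers once; pin the host one so the copies can be async
    host_buf = torch.empty(host_src.shape, dtype=host_src.dtype, pin_memory=use_cuda)
    host_buf.copy_(host_src)
    dev_buf = torch.empty_like(host_src, device=device)
    for _ in range(repeats):
        # Same-stream ordering keeps each copy from starting before the previous one ends
        dev_buf.copy_(host_buf, non_blocking=use_cuda)  # H2D into the existing buffer
        host_buf.copy_(dev_buf, non_blocking=use_cuda)  # D2H into the existing buffer
    if use_cuda:
        # The last D2H copy may still be in flight: wait before the CPU reads host_buf
        torch.cuda.current_stream(device).synchronize()
    return host_buf
```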