Why is tensor.pin_memory().cuda() significantly slower?

I compared the time of tensor.pin_memory().cuda() with tensor.cuda(), and the first one is significantly slower. My assumption is that tensor.cuda() also copies the data from pageable memory to pinned memory internally, so the two should be of similar speed, yet pin_memory().cuda() is almost 10 times slower.

#include <torch/extension.h>
#include <vector>
#include <thread>
#include <mutex>
#include <sys/time.h>

// Helper to pin a single tensor (not used in the loop below)
void pin_tensor(torch::Tensor& tensor) {
    tensor = tensor.pin_memory();
}

// Custom op to pin a list of tensors in parallel
std::vector<torch::Tensor> pin_tensors_parallel(const std::vector<torch::Tensor>& tensors) {
    struct timeval ts0, ts1;
    gettimeofday(&ts0, NULL);

    std::vector<torch::Tensor> pinned_tensors(tensors.size());
    torch::Device cuda_device(torch::kCUDA, 0);

    for (size_t i = 0; i < tensors.size(); ++i) {
        // Path 1: pin the tensor first, then copy it to the GPU
        pinned_tensors[i] = tensors[i].pin_memory().to(cuda_device);

        // Path 2 (for comparison): copy the pageable tensor to the GPU directly
        // pinned_tensors[i] = tensors[i].to(cuda_device);
    }


    torch::cuda::synchronize();
    gettimeofday(&ts1, NULL);

    // elapsed wall-clock time in microseconds
    double elapsed_us = (ts1.tv_sec - ts0.tv_sec) * 1000 * 1000 + (ts1.tv_usec - ts0.tv_usec);

    printf("total time=%.3g ms\n", elapsed_us / 1000.0);

    return pinned_tensors;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pin_tensors_parallel", &pin_tensors_parallel, "Pin a list of tensors in parallel");
}
# Python driver script
import torch
import my_custom_op
import time

# Create a list of tensors
tensors = [torch.randn(1000, 1000) for _ in range(10)]

# the first time is always slower
tensors[0].pin_memory()

cuda_tensors = my_custom_op.pin_tensors_parallel(tensors)

torch.cuda.synchronize(device=0)

tensor.pin_memory().cuda() takes around 100 ms;
tensor.cuda() takes around 15 ms.
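For reference, the same comparison can be reproduced in pure Python without the C++ extension. The following is only a minimal sketch, assuming a CUDA device is available; the tensor sizes and counts are illustrative:

import time
import torch

def time_ms(fn, tensors):
    # run fn on every tensor and return the wall-clock time in milliseconds
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for t in tensors:
        fn(t)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0

tensors = [torch.randn(1000, 1000) for _ in range(10)]
torch.empty(1, device="cuda")   # warm up the CUDA context
tensors[0].pin_memory()         # warm up: the first pin_memory() call is always slower

print(f"pin_memory().cuda(): {time_ms(lambda t: t.pin_memory().cuda(), tensors):.1f} ms")
print(f"cuda():              {time_ms(lambda t: t.cuda(), tensors):.1f} ms")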

I assume this is due to introducing an additional move/copy.

When a tensor is allocated on the CPU side it can sit pretty much anywhere in RAM (an oversimplification, of course), but pinned (page-locked) memory is a special region that the OS won't page out, which allows more efficient movement to and from the GPU.

When you first create the tensors, Torch has no idea that you might want them moved to a GPU at some point, so they aren't placed in pinned memory. However, when you call pin_memory() the tensor is copied into this pinned region, which can take some time to allocate. You then call cuda(), which copies it to the GPU.
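As an aside (a sketch, not from the original post): if you already know a tensor is headed for the GPU, you can allocate it page-locked from the start via the pin_memory=True flag on the factory functions, so no later pageable-to-pinned copy is needed; the shape here is illustrative:

import torch

x = torch.empty(1000, 1000, pin_memory=True)  # page-locked from the moment it is allocated
x.normal_()                                   # fill it in place
assert x.is_pinned()
x_gpu = x.cuda()                              # only the host-to-device copy remains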

Skipping the pinning step skips this extra allocation and copy, which makes the operation faster overall. BUT if you only measure the time after pinning has completed, the copy to the GPU will be faster than if pinning hadn't occurred at all.

Pinning is useful when you're "setting up ahead of time" and all you care about is the final copy-to-GPU step… but not so much if you're allocating CPU-side and then immediately moving to the GPU.
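To see this, you can pin ahead of time and then time only the host-to-device copy. A rough sketch (assuming a CUDA device; on most systems the copy from the pinned tensor comes out faster; names and sizes are illustrative):

import time
import torch

pageable = torch.randn(1000, 1000)
pinned = pageable.pin_memory()   # the pinning cost is paid once, ahead of time

def h2d_ms(src):
    # time a single host-to-device copy in milliseconds
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    src.cuda()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0

print(f"pageable -> GPU: {h2d_ms(pageable):.2f} ms")
print(f"pinned   -> GPU: {h2d_ms(pinned):.2f} ms")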

In the above code, both paths measure the time from CPU pageable memory to GPU memory, so why does pinning first slow things down so much (100 ms vs 15 ms)? It should be of similar speed to tensor.cuda().

After all, tensor.cuda() also copies the data into a pinned staging buffer internally and then uses DMA to copy it to GPU memory.

@ptrblck

Apparently a 2x slowdown is expected: A guide on good usage of non_blocking and pin_memory() in PyTorch — PyTorch Tutorials 2.6.0+cu124 documentation

Reading more of the details from that link, I suspect that the bigger slowdown you’re seeing is because tensor.cuda() is non-blocking by default, whereas tensor.pin_memory() is blocking by default, and you’re doing a bunch of little tensors which amplifies the overhead.
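Along those lines, here is a sketch of the usual recommendation (not the original poster's code): pin ahead of time, then issue the small copies with non_blocking=True so they are only queued on the stream and a single synchronize waits for all of them; counts and sizes are illustrative:

import torch

pinned = [torch.randn(1000, 1000).pin_memory() for _ in range(10)]

# with pinned sources, non_blocking=True only enqueues the copies;
# the host is not blocked once per tensor
gpu_tensors = [t.to("cuda", non_blocking=True) for t in pinned]
torch.cuda.synchronize()  # one wait for all of the queued copies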

.cuda() is blocking by default. @ptrblck, your explanation is needed here.