Introducing SpeedTorch: 4x faster CPU->GPU transfer, 110x faster GPU->CPU transfer

This is my submission to the PyTorch Hackathon. It’s a library I made for PyTorch, for fast transfer between pinned CPU tensors and GPU PyTorch variables. The inspiration came from needing to train a large number of embeddings, which don’t all fit in GPU RAM at the desired embedding size, so I needed a faster CPU <-> GPU transfer method. This also allows using any optimizer for sparse training, since every embedding contained in the PyTorch embedding variable receives an update; previously only PyTorch’s SGD, Adagrad, and SparseAdam were suitable for such training.

In addition to augmenting parameter sizes, you can use it to increase the speed at which data on your CPU is transferred to PyTorch CUDA variables.

Also, SpeedTorch’s GPU tensors are overall faster than PyTorch CUDA tensors when taking into account transfers both to and from them (overall 2.6x faster). For just transferring to a PyTorch CUDA variable, PyTorch is still faster, but it is significantly slower when transferring from a PyTorch CUDA variable.

I have personally used this to nearly double the embedding size in two other projects by holding half the parameters on the CPU. The training speed stays decent thanks to the fast CPU <-> GPU exchange.

There’s a bit of a learning curve when first getting started with it, so if you run into any sort of friction, feel free to ask a question on the project Gitter

https://gitter.im/SpeedTorch

and I’ll answer it.


Interesting work. This can speed-up operations like:

cpu_tensor[idx] = gpu_tensor.cpu()

and

gpu_tensor = cpu_tensor[idx].cuda()

The library does this by masquerading the pinned-memory CPU tensor as a CuPy GPU tensor and using CuPy’s GPU indexing kernels. It’s not really speeding up the CPU-GPU copy part; it’s avoiding the overhead of a separate CPU indexing operation. (The GitHub page mentions “memmaps”, but it’s not using memory mapping at all.)
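
To illustrate the trick (just a sketch of the general approach, not SpeedTorch’s actual code): you can allocate pinned host memory and wrap it in a cupy.ndarray, so that CuPy’s indexing kernels address it directly through CUDA’s unified virtual addressing:

import numpy as np
import cupy

shape, dtype = (131072, 128), np.float32
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize

pinned = cupy.cuda.alloc_pinned_memory(nbytes)  # page-locked host buffer
mem = cupy.cuda.UnownedMemory(pinned.ptr, nbytes, pinned, cupy.cuda.Device().id)
host_backed = cupy.ndarray(shape, dtype=dtype, memptr=cupy.cuda.MemoryPointer(mem, 0))

# Indexed writes such as host_backed[idx] = some_cupy_gpu_array now run as
# CuPy GPU kernels that write straight into pinned CPU memory, skipping the
# separate CPU-side indexing pass.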

For CPU <-> GPU copies, the relative speed-up depends on how slow the PyTorch CPU indexing operation is. For the (131072, 128) size, this can be as fast as 2 ms on a multi-core machine or ~25 ms on a Colab notebook with only one or two virtual CPUs. For comparison, the CPU-GPU copy from pinned memory is consistent at ~5-6 ms. (Again, it’s not the CPU-GPU copy that’s sped up; it’s the separate CPU indexing operation that’s avoided.)
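
For reference, the plain pinned-memory copy being compared against can be reproduced with stock PyTorch (a baseline sketch, not SpeedTorch code):

import time
import torch

src = torch.empty(131072, 128).pin_memory()  # page-locked host tensor
dst = torch.empty(131072, 128, device='cuda')

torch.cuda.synchronize()
start = time.time()
dst.copy_(src, non_blocking=True)  # asynchronous host-to-device copy from pinned memory
torch.cuda.synchronize()           # wait for the copy to finish before stopping the clock
print((time.time() - start) * 1000, 'ms')  # roughly the ~5-6 ms quoted above on typical hardware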

There’s also a significant indexing performance bug in PyTorch 1.1 and 1.2 (fixed in the nightly builds) that makes the indexing much slower. The apparent GPU <-> GPU indexing speed-ups are entirely due to this bug. You could either use the nightly builds or use index_select / index_copy_ instead of a[idx] notation in 1.1/1.2 to avoid that slow down in vanilla PyTorch. (Bug is described in #24083 for reference)
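
For example, the a[idx] gather and a[idx] = ... scatter can be rewritten to sidestep the regression (a sketch with arbitrary shapes):

import torch

src = torch.randn(131072, 128, device='cuda')
dst = torch.zeros(131072, 128, device='cuda')
idx = torch.randint(0, 131072, (4096,), device='cuda')

# Affected by the 1.1/1.2 advanced-indexing slowdown:
#   gathered = src[idx]
#   dst[idx] = gathered

# Equivalent, but unaffected:
gathered = torch.index_select(src, 0, idx)  # same result as src[idx]
dst.index_copy_(0, idx, gathered)           # same result as dst[idx] = gathered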

The library is missing some synchronization. Particularly, when copying from GPU to pinned memory (masquerading as GPU via cupy), you need to synchronize before accessing the CPU data; otherwise it may not be consistent.
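
Concretely (using the names from the benchmark code below; a sketch, not a statement about SpeedTorch’s internals), the GPU -> pinned-host path needs something like:

# GPU -> pinned host, launched as an asynchronous CuPy indexing kernel
gadgetCPU.insertData(u_embeddings.weight.data, indexess)

# Synchronize before consuming the pinned CPU buffer; otherwise the host
# may read stale or partially written data.
torch.cuda.synchronize()
cupy.cuda.Device().synchronize()
# ... only now read or reuse the CPU-side copy ...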

There are a few bugs in the benchmark code, mostly minor:

  1. sampl = np.random.uniform(low=-1.0, high=1.0, size=(1000000, 128)). This is float64; it should be float32 to match everything else:
    sampl = np.random.uniform(low=-1.0, high=1.0, size=(1000000, 128)).astype(np.float32)

  2. The timing code should have a torch.cuda.synchronize() at the end. Instead of:

torch.cuda.synchronize()
cupy.cuda.Device().synchronize()

runningTime=0
for i in range(numSamples):
    start = time.time()

    with torch.no_grad():

        gadgetCPU.insertData(u_embeddings.weight.data, indexess )

    end = time.time()
    runningTime = runningTime + end - start

print('set corpus. cupy live pinned')
print(runningTime/numSamples)

It should be written as:

torch.cuda.synchronize()
cupy.cuda.Device().synchronize()

start = time.time()
for i in range(numSamples):
    with torch.no_grad():
        gadgetCPU.insertData(u_embeddings.weight.data, indexess )

torch.cuda.synchronize()  # wait for all queued GPU work to finish before stopping the clock
end = time.time()
runningTime = end - start

print('set corpus. cupy live pinned')
print(runningTime/numSamples)

Thanks for the detailed analysis.
I had a quick look at the implementation and, besides the missing synchronizations, I’m still trying to figure out if CuPy arrays can hold any CPU data, or if they’re restricted to the device only.
E.g. while this line of code

gadgetCPU.insertData(u_embeddings.weight.data,  indexess )

seems to copy the CUDA tensor from u_embeddings onto the CPU, it looks as if the underlying data never left the GPU:

gadgetCPU.CUPYcorpus.device
Out[4]: <CUDA Device 0>

I’m not really experienced in CuPy, so I might be missing something.


As far as I can tell, CuPy is only intended to hold CUDA data, but in this case it’s actually holding CPU data (pinned memory). You can check with something like:

cupy.cuda.runtime.pointerGetAttributes(gadgetCPU.CUPYcorpus.data.ptr).memoryType

This will print 1 (= cudaMemoryTypeHost). On gadgetGPU it’ll print 2 (= cudaMemoryTypeDevice). (cudaMemoryType reference)

You can do something similar in PyTorch from the C++ API using torch::from_blob. Here’s an example. Note there’s a check in from_blob that tries to prevent this sort of thing, but the check is broken (what luck!).


Thanks for the detailed analysis! I was very curious about what was going on.

The GitHub page mentions “memmaps”, but it’s not using memory mapping at all

Ahhh, that’s a vestige from my initial approach using memmaps; I overlooked changing that. It’s fixed now.

I updated the benchmarking code, and I’ll link to this analysis from the ‘how it works’ section.

I’ll also include a section for when to use SpeedTorch.

you need to synchronize before accessing the CPU data; otherwise it may not be consistent.

So adding torch.cuda.synchronize() and cupy.cuda.Device().synchronize() before the transfer? By consistency, do you mean in terms of transfer times?