Introducing SpeedTorch: 4x speed CPU->GPU transfer, 110x GPU->CPU transfer

Interesting work. This can speed up operations like:

cpu_tensor[idx] = gpu_tensor.cpu()

and

gpu_tensor = cpu_tensor[idx].cuda()

The library does this by masquerading a pinned-memory CPU tensor as a CuPy GPU tensor and using CuPy's GPU indexing kernels on it. It's not really speeding up the CPU-GPU copy itself; it's avoiding the overhead of a separate CPU indexing operation. (The GitHub page mentions “memmaps”, but it's not using memory mapping at all.)
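For illustration, here's a minimal sketch of that trick, assuming a recent CuPy and PyTorch; the variable names are mine, not SpeedTorch's API:

import torch
import cupy
from torch.utils.dlpack import to_dlpack

# Wrap a pinned CPU tensor's buffer in a CuPy ndarray so that CuPy's GPU
# indexing kernels can read and write it directly (zero-copy over PCIe).
cpu_tensor = torch.empty(1000000, 128, pin_memory=True)
nbytes = cpu_tensor.numel() * cpu_tensor.element_size()
mem = cupy.cuda.UnownedMemory(cpu_tensor.data_ptr(), nbytes, cpu_tensor)
memptr = cupy.cuda.MemoryPointer(mem, 0)
cpu_as_cupy = cupy.ndarray(cpu_tensor.shape, dtype=cupy.float32, memptr=memptr)

gpu_tensor = torch.randn(131072, 128, device='cuda')
gpu_as_cupy = cupy.fromDlpack(to_dlpack(gpu_tensor))
idx = cupy.arange(131072)

# The scatter runs as a CuPy GPU kernel that writes straight into the
# pinned host buffer; no separate CPU-side indexing pass happens.
cpu_as_cupy[idx] = gpu_as_cupy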

For CPU <-> GPU copies, the relative speed-up depends on how slow PyTorch's CPU indexing operation is. For the (131072, 128) size, the CPU indexing can be as fast as ~2 ms on a multi-core machine or as slow as ~25 ms on a Colab notebook with only one or two virtual CPUs. For comparison, the CPU-GPU copy from pinned memory is consistently ~5-6 ms. (Again, it's not the CPU-GPU copy that's sped up; it's the separate CPU indexing operation that's avoided.)
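You can measure the two pieces separately with something like this (a rough sketch; the tensor names and sizes are stand-ins):

import time
import torch

corpus = torch.randn(1000000, 128)
idx = torch.randperm(corpus.size(0))[:131072]
pinned = torch.empty(131072, 128, pin_memory=True)

torch.cuda.synchronize()

# Piece 1: the CPU-side gather. This is the part whose cost varies with
# core count (~2 ms on a multi-core machine, ~25 ms on Colab).
t0 = time.time()
torch.index_select(corpus, 0, idx, out=pinned)
t1 = time.time()

# Piece 2: the actual pinned-host-to-device copy (~5-6 ms at this size).
gpu = pinned.cuda(non_blocking=True)
torch.cuda.synchronize()
t2 = time.time()

print('CPU gather: %.1f ms, H2D copy: %.1f ms' % ((t1 - t0) * 1e3, (t2 - t1) * 1e3))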

There's also a significant indexing performance bug in PyTorch 1.1 and 1.2 (fixed in the nightly builds) that makes indexing much slower. The apparent GPU <-> GPU indexing speed-ups are entirely due to this bug. You can either use the nightly builds or use index_select / index_copy_ instead of the a[idx] notation to avoid that slowdown in vanilla PyTorch 1.1/1.2. (The bug is described in #24083 for reference.)
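For example, these calls give the same results as the bracket notation but avoid the slow path:

import torch

corpus = torch.randn(1000000, 128, device='cuda')
idx = torch.randperm(corpus.size(0), device='cuda')[:131072]

# Same result as rows = corpus[idx], without the slow advanced-indexing path:
rows = corpus.index_select(0, idx)

# Same result as corpus[idx] = rows (idx has no duplicate entries here):
corpus.index_copy_(0, idx, rows)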

The library is also missing some synchronization. In particular, when copying from the GPU to pinned memory (masquerading as GPU memory via CuPy), you need to synchronize before accessing the data on the CPU; otherwise it may not be consistent.
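In terms of the sketch above, that means something like:

# The scatter into the pinned buffer is launched asynchronously, so the
# CPU may read stale data unless we wait for the kernel to finish.
cpu_as_cupy[idx] = gpu_as_cupy
cupy.cuda.Device().synchronize()   # or torch.cuda.synchronize()
value = cpu_tensor[0, 0]           # now safe to read on the CPU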

There are a few bugs in the benchmark code, mostly minor:

  1. sampl = np.random.uniform(low=-1.0, high=1.0, size=(1000000, 128)). This is float64; it should be float32 to match everything else:
    sampl = np.random.uniform(low=-1.0, high=1.0, size=(1000000, 128)).astype(np.float32)

  2. The timing code should have a torch.cuda.synchronize() before the final time.time() call, since the measured operations are asynchronous. Instead of:

torch.cuda.synchronize()
cupy.cuda.Device().synchronize()

runningTime = 0
for i in range(numSamples):
    start = time.time()

    with torch.no_grad():
        gadgetCPU.insertData(u_embeddings.weight.data, indexess)

    # Bug: the GPU work launched by insertData may still be running here,
    # so time.time() can stop the clock before the copy has finished.
    end = time.time()
    runningTime = runningTime + end - start

print('set corpus. cupy live pinned')
print(runningTime/numSamples)

It should be written as:

torch.cuda.synchronize()
cupy.cuda.Device().synchronize()

start = time.time()
for i in range(numSamples):
    with torch.no_grad():
        gadgetCPU.insertData(u_embeddings.weight.data, indexess)

# Wait for all queued GPU work to finish before stopping the clock.
torch.cuda.synchronize()
end = time.time()
runningTime = end - start

print('set corpus. cupy live pinned')
print(runningTime/numSamples)