Slow CPU<=>GPU transfer

Hi,

I am having issues with CPU<=>GPU transfers, which are extremely slow in PyTorch on a V100.
I have a CPU-to-GPU transfer test script. When I pin the memory and just transfer, the transfer time is normal:

import time

import torch

rank = 0


N = 6000
H = 4000
K = 10

gpu_device = torch.device("cuda:"+str(rank))

w1 = torch.randn(N, H, device=torch.device("cpu"))
w2 = torch.randn(N, H, device=torch.device("cpu"))
w3 = torch.randn(N, H, device=torch.device("cpu"))
w4 = torch.randn(N, H, device=torch.device("cpu"))
w5 = torch.randn(N, H, device=torch.device("cpu"))
w6 = torch.randn(N, H, device=torch.device("cpu"))
w7 = torch.randn(N, H, device=torch.device("cpu"))
w8 = torch.randn(N, H, device=torch.device("cpu"))

ww1 = torch.nn.Parameter(w1).pin_memory()
ww2 = torch.nn.Parameter(w2).pin_memory()
ww3 = torch.nn.Parameter(w3).pin_memory()
ww4 = torch.nn.Parameter(w4).pin_memory()
ww5 = torch.nn.Parameter(w5).pin_memory()
ww6 = torch.nn.Parameter(w6).pin_memory()
ww7 = torch.nn.Parameter(w7).pin_memory()
ww8 = torch.nn.Parameter(w8).pin_memory()

x1 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x2 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x3 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x4 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x5 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x6 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x7 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x8 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)

# ww1 = ww1+1
# ww2 = ww2+1
# ww3 = ww3+1
# ww4 = ww4+1
# ww5 = ww5+1
# ww6 = ww6+1
# ww7 = ww7+1
# ww8 = ww8+1

torch.cuda.synchronize()
t_start = time.time()

for i in range(K):

    x1.copy_(ww1.data, non_blocking=True)
    x2.copy_(ww2.data, non_blocking=True)
    x3.copy_(ww3.data, non_blocking=True)
    x4.copy_(ww4.data, non_blocking=True)
    x5.copy_(ww5.data, non_blocking=True)
    x6.copy_(ww6.data, non_blocking=True)
    x7.copy_(ww7.data, non_blocking=True)
    x8.copy_(ww8.data, non_blocking=True)

torch.cuda.synchronize()
t_end = time.time()

t_time = (t_end - t_start)*1000
t_size = K*N*H*4*8/1024/1024
t_bw = t_size/(t_end-t_start)


print('size of host-to-device transfer  ', t_size, 'MB')
print('time taken by transfer  ', t_time, 'ms')
print('Effective bandwidth ', t_bw, 'MB/s')

I get 11 GB/s of bandwidth here. But when I modify the tensors before the transfer, it becomes extremely slow:

import time

import torch

rank = 0


N = 6000
H = 4000
K = 10

gpu_device = torch.device("cuda:"+str(rank))

w1 = torch.randn(N, H, device=torch.device("cpu"))
w2 = torch.randn(N, H, device=torch.device("cpu"))
w3 = torch.randn(N, H, device=torch.device("cpu"))
w4 = torch.randn(N, H, device=torch.device("cpu"))
w5 = torch.randn(N, H, device=torch.device("cpu"))
w6 = torch.randn(N, H, device=torch.device("cpu"))
w7 = torch.randn(N, H, device=torch.device("cpu"))
w8 = torch.randn(N, H, device=torch.device("cpu"))

ww1 = torch.nn.Parameter(w1).pin_memory()
ww2 = torch.nn.Parameter(w2).pin_memory()
ww3 = torch.nn.Parameter(w3).pin_memory()
ww4 = torch.nn.Parameter(w4).pin_memory()
ww5 = torch.nn.Parameter(w5).pin_memory()
ww6 = torch.nn.Parameter(w6).pin_memory()
ww7 = torch.nn.Parameter(w7).pin_memory()
ww8 = torch.nn.Parameter(w8).pin_memory()

x1 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x2 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x3 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x4 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x5 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x6 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x7 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)
x8 = torch.randn(N, H, dtype=torch.float32, device=gpu_device)

ww1 = ww1+1
ww2 = ww2+1
ww3 = ww3+1
ww4 = ww4+1
ww5 = ww5+1
ww6 = ww6+1
ww7 = ww7+1
ww8 = ww8+1

torch.cuda.synchronize()
t_start = time.time()

for i in range(K):

    x1.copy_(ww1.data, non_blocking=True)
    x2.copy_(ww2.data, non_blocking=True)
    x3.copy_(ww3.data, non_blocking=True)
    x4.copy_(ww4.data, non_blocking=True)
    x5.copy_(ww5.data, non_blocking=True)
    x6.copy_(ww6.data, non_blocking=True)
    x7.copy_(ww7.data, non_blocking=True)
    x8.copy_(ww8.data, non_blocking=True)

torch.cuda.synchronize()
t_end = time.time()

t_time = (t_end - t_start)*1000
t_size = K*N*H*4*8/1024/1024
t_bw = t_size/(t_end-t_start)


print('size of host-to-device transfer  ', t_size, 'MB')
print('time taken by transfer  ', t_time, 'ms')
print('Effective bandwidth ', t_bw, 'MB/s')

Here the bandwidth drops to 3-4 GB/s!
Is that expected? How can I optimize the transfer time between GPU and CPU?

PS: I tested pure CUDA transfers and get 12 GB/s, which is expected. Also, without pin_memory(), I get 3-4 GB/s regardless of whether I change the tensors.
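One way to compare pinned and pageable sources more directly is to time the copies with CUDA events instead of time.time(). The sketch below assumes a CUDA build of PyTorch; the helper names (h2d_bandwidth_gbps, compare_pinned_vs_pageable) are hypothetical, and the numbers it prints depend on the machine and its PCIe setup.

```python
import torch


def h2d_bandwidth_gbps(src: torch.Tensor, dst: torch.Tensor, iters: int = 10) -> float:
    """Time `iters` host-to-device copies of `src` into `dst` with CUDA events.

    Returns the effective bandwidth in GB/s (1 GB = 1e9 bytes).
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    dst.copy_(src, non_blocking=True)  # warm-up copy, excluded from timing
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        # Only truly asynchronous when `src` is pinned; pageable sources are staged.
        dst.copy_(src, non_blocking=True)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000.0  # elapsed_time() returns milliseconds
    return iters * src.numel() * src.element_size() / seconds / 1e9


def compare_pinned_vs_pageable(n: int = 6000, h: int = 4000) -> dict:
    """Measure pageable vs. pinned host-to-device bandwidth; {} without CUDA."""
    if not torch.cuda.is_available():
        return {}
    pageable = torch.randn(n, h)
    pinned = pageable.pin_memory()  # page-locked copy of the same data
    dst = torch.empty(n, h, device="cuda")
    return {
        "pageable": h2d_bandwidth_gbps(pageable, dst),
        "pinned": h2d_bandwidth_gbps(pinned, dst),
    }


if __name__ == "__main__":
    print(compare_pinned_vs_pageable())
```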

Your operation ww1 = ww1 + 1 creates a new tensor, which is no longer pinned.
You can check this via print(ww1.is_pinned()), which should return False once the new tensor has been created.
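A minimal sketch of that check, assuming a CUDA-capable machine (pin_memory() needs one); the helper name is hypothetical:

```python
import torch


def pinned_state_after_out_of_place_add():
    """Return (before, after) is_pinned() flags for `t = t + 1`, or None without CUDA."""
    if not torch.cuda.is_available():
        return None  # pin_memory() needs a CUDA-capable setup
    t = torch.randn(4, 4).pin_memory()
    before = t.is_pinned()
    t = t + 1  # out-of-place: allocates a new, ordinary (pageable) tensor
    after = t.is_pinned()
    return before, after


print(pinned_state_after_out_of_place_add())  # (True, False) on a CUDA machine
```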

I see. So is there any way to transfer more efficiently after changing the values? Let’s say I want to transfer model weights during training (which requires changing the values). This low bandwidth is hurting that.

You could use in-place methods (e.g. ww1.add_(1.)) to keep the pinned memory block alive.
Would that work for you, or are you using more complicated methods (which might not have in-place versions)?
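Applied to a transfer loop, one option (a sketch, not an official PyTorch utility) is to keep a single pinned host buffer and write each new set of values into it with copy_(), which is also an in-place method. PinnedUploader is a hypothetical name; the event wait guards against overwriting the buffer while an earlier asynchronous copy may still be reading it.

```python
import torch


class PinnedUploader:
    """Hypothetical helper: reuse one pinned host buffer for repeated uploads.

    New values are written into the buffer with copy_(), an in-place method,
    so the page-locked allocation is never replaced and each host-to-device
    copy can stay asynchronous.
    """

    def __init__(self, shape, device="cuda"):
        self.host = torch.empty(shape).pin_memory()
        self.dev = torch.empty(shape, device=device)
        self.copied = torch.cuda.Event()

    def upload(self, values: torch.Tensor) -> torch.Tensor:
        # Wait until the previous asynchronous copy has finished reading self.host
        # (a no-op before the first upload, when the event was never recorded).
        self.copied.synchronize()
        self.host.copy_(values)  # in-place write: self.host stays pinned
        self.dev.copy_(self.host, non_blocking=True)
        # Record on the destination device's current stream, after the copy.
        self.copied.record(torch.cuda.current_stream(self.dev.device))
        return self.dev


if __name__ == "__main__" and torch.cuda.is_available():
    uploader = PinnedUploader((6000, 4000))
    for step in range(3):
        w = torch.randn(6000, 4000) + step  # values changed out of place
        out = uploader.upload(w)
    torch.cuda.synchronize()
    print(uploader.host.is_pinned())
```

The trade-off is one extra host-side copy_() per upload; when the values can be changed directly in a pinned tensor (e.g. with add_()), the staging buffer is unnecessary.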

I am running the forward and backward passes on the GPU and transferring the gradients to the CPU, then running the optimizer on the CPU and transferring the updated weights back to the GPU. I see different performance on different machines, and all of them are much lower than expected.
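For reference, here is one way such a loop could be arranged so that the CPU-side tensors stay pinned across iterations. It is a sketch under assumptions: make_cpu_masters and train_step are hypothetical names, plain SGD stands in for whatever optimizer is actually used, a single GPU is assumed, and pinning is skipped when CUDA is unavailable so the code still runs on a CPU. A non_blocking device-to-host copy must complete before the CPU reads the destination, hence the synchronize before optimizer.step().

```python
import torch
import torch.nn as nn


def make_cpu_masters(model: nn.Module, pin: bool) -> list:
    """Hypothetical helper: CPU master copies of the model's parameters.

    Each master (and its .grad buffer) is allocated once and pinned when
    `pin` is True, so later copies in both directions can be asynchronous.
    """
    masters = []
    for p in model.parameters():
        host = p.detach().to("cpu", copy=True)  # own storage, never aliases p
        grad = torch.zeros_like(host)
        if pin:
            host, grad = host.pin_memory(), grad.pin_memory()
        master = nn.Parameter(host)  # wraps the (pinned) storage without copying
        master.grad = grad
        masters.append(master)
    return masters


def train_step(model, masters, optimizer, inputs, targets, loss_fn) -> float:
    """One iteration: forward/backward on the model's device, update on the CPU."""
    model.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    # 1) Gradients: device -> pinned CPU buffers, queued asynchronously.
    for p, m in zip(model.parameters(), masters):
        m.grad.copy_(p.grad, non_blocking=True)
    # 2) The CPU must not read m.grad until those copies have completed.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # 3) The optimizer updates the masters in place, so they stay pinned.
    optimizer.step()
    # 4) Updated weights: pinned CPU -> device, queued asynchronously.
    with torch.no_grad():
        for p, m in zip(model.parameters(), masters):
            p.copy_(m, non_blocking=True)
    return loss.item()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(4000, 4000).to(device)
    masters = make_cpu_masters(model, pin=torch.cuda.is_available())
    optimizer = torch.optim.SGD(masters, lr=0.01)
    x = torch.randn(64, 4000, device=device)
    y = torch.randn(64, 4000, device=device)
    print(train_step(model, masters, optimizer, x, y, nn.MSELoss()))
```

The sketch deliberately never calls optimizer.zero_grad(): in recent releases it sets .grad to None by default, which would discard the pinned gradient buffers. Step 1 overwrites them with copy_() instead.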

The optimizers should update the parameters in place, as shown here:

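import torch
import torch.nn as nn

p = nn.Parameter(torch.ones(10).pin_memory())
print(p.is_pinned())  # pinned at creation
optimizer = torch.optim.Adam([p], lr=1.)
p.grad = torch.ones_like(p)
print(p.is_pinned())  # assigning .grad does not touch p's storage
optimizer.step()
print(p.is_pinned())  # Adam updates p in place, so it stays pinned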

What is your use case for updating the parameters on the CPU?

It is part of a multi-process parallel algorithm we are using that requires it.

Could you check the is_pinned() state before and after executing the step() method?
If you are somehow creating new tensors (and thus reallocate new memory), you would need to pin them again.
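A small hypothetical helper for that re-pinning step is sketched below. Keep in mind that pin_memory() returns a pinned copy rather than pinning the tensor in place, so the caller has to keep using the returned tensor, and every re-pin adds a host-side copy.

```python
import torch


def ensure_pinned(t: torch.Tensor) -> torch.Tensor:
    """Return `t` if it is already page-locked (or not on the CPU), else a pinned copy.

    pin_memory() copies the data into a new page-locked allocation, so the
    caller must keep the returned tensor instead of the original one.
    """
    if t.device.type != "cpu" or t.is_pinned():
        return t
    return t.pin_memory()


if __name__ == "__main__" and torch.cuda.is_available():
    w = torch.randn(1000, 1000).pin_memory()
    w = w + 1  # out-of-place op: w is no longer pinned
    w = ensure_pinned(w)  # re-pinned copy
    print(w.is_pinned())
```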

I checked, and is_pinned() returns True both before and after the optimizer step. But I don’t see any performance difference with and without pin_memory().

I found this thread interesting. On a related note:

cuda(non_blocking=True)

has quite a significant delay (~50% of a blocking call).
Is there any way to reduce the delay of the non-blocking call?
This is critical for transferring data from the CPU to multiple GPUs. Ideally we’d like to do the data transfers in parallel, e.g.

x = torch.randn(1000, 1000, 1000)
xs = [x.cuda(i, non_blocking=True) for i in range(8)]
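The PyTorch documentation for Tensor.cuda() notes that non_blocking=True only makes the copy asynchronous with respect to the host when the source is in pinned memory; otherwise the argument has no effect. In the snippet above x is pageable, so a sketch that pins it once and then issues the per-device copies might look like this (broadcast_from_pinned is a hypothetical name, and whether the copies actually overlap depends on the PCIe topology and host memory bandwidth):

```python
import torch


def broadcast_from_pinned(x: torch.Tensor) -> list:
    """Copy a host tensor to every visible GPU, letting the copies overlap.

    non_blocking=True is only asynchronous for pinned sources, so pin once up front.
    """
    if not x.is_pinned():
        x = x.pin_memory()  # one-time cost; reuse the pinned tensor across calls
    n = torch.cuda.device_count()
    # Each copy is queued on device i's current stream and returns immediately.
    xs = [x.to(f"cuda:{i}", non_blocking=True) for i in range(n)]
    for i in range(n):
        torch.cuda.synchronize(i)  # wait until the copy to device i has landed
    return xs


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(1000, 1000, 100).pin_memory()
    xs = broadcast_from_pinned(x)
    print([t.device for t in xs])
```

Pinning a tensor as large as the 1000 x 1000 x 1000 example (about 4 GB of float32) takes noticeable time itself, so this pays off mainly when the pinned tensor is reused across transfers.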