During a profiling session, I got the impression that torch.tensor(x, device='cuda') is blocking. That is, the Python script only carries on once the data is effectively on the GPU. If an experiment launches many small kernels, it can then be difficult to refill the GPU queue and bring GPU utilization back up.
If torch.tensor is indeed blocking, one should instead use torch.tensor(x, pin_memory=True).to('cuda', non_blocking=True) to avoid the synchronization point (pinned memory alone is not enough; the host-to-device copy only becomes asynchronous with non_blocking=True).
I have not been able to showcase this behaviour in a toy example, and I don’t know where to look in the PyTorch code. If someone savvy could provide some clarification, that would be great.
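For context, here is a minimal sketch of the pinned-memory pattern mentioned above (the tensor shape and variable names are illustrative assumptions, not taken from an actual workload):

import torch

x = torch.randn(1000, 1000)                    # ordinary pageable CPU tensor (illustrative size)
x_pinned = x.pin_memory()                      # stage the data in page-locked host memory
y = x_pinned.to('cuda', non_blocking=True)     # asynchronous host-to-device copy; returns immediately
# ... enqueue further GPU work here ...
torch.cuda.synchronize()                       # only block when the result is actually needed

The asynchronous copy only helps if other work can actually be queued between the transfer and the point where y is used.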
This is to be expected (at least by me). If x resides on the cpu, then the cpu will be busy and “block” until it is done with its part of copying the data to the gpu.

But I believe that torch.tensor() is non-blocking in the sense that it will return before the gpu is done with whatever work it has to do. However, this non-blocking behavior is hidden because the cpu data-transfer time is the bottleneck compared to any pure gpu work (at least on my test system). The non-blocking behavior becomes apparent when x itself resides on the gpu (or presumably on a second gpu).
Here is a timing script:
import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())
import time
import warnings
warnings.filterwarnings ('ignore')   # clean up output

for source_device in ('cpu', 'cuda'):
    print ('source_device:', source_device)
    for iMeg in (100, 200, 400, 800):
        if source_device == 'cuda' and iMeg > 400:   # avoid out of memory
            break
        t_source = torch.randn (iMeg, 1000, 1000, device = source_device)
        # warmup
        tc0 = torch.tensor (t_source, device = 'cuda')
        tc0 = None
        torch.cuda.synchronize()   # make sure gpu is ready
        t0 = time.time()
        tc1 = torch.tensor (t_source, device = 'cuda')
        t1 = time.time()
        print ('iMeg:', iMeg, '  t_nosync: ', t1 - t0)
        torch.cuda.synchronize()
        tc1 = None
        torch.cuda.synchronize()   # make sure gpu is ready
        t0 = time.time()
        tc2 = torch.tensor (t_source, device = 'cuda')
        torch.cuda.synchronize()   # wait for torch.tensor() to actually finish
        t1 = time.time()
        print ('iMeg:', iMeg, '  t_sync: ', t1 - t0)
        tc2 = None
        t_source = None
The fact that the gpu → gpu “nosync” timings are much shorter than the analogous “sync” timings shows that torch.tensor (x, device = 'cuda') returns asynchronously when x resides on the gpu.
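As a possible cross-check (a sketch under my own assumptions, not part of the original reply), CUDA events can time the device-side copy itself, independently of when the Python call returns; the tensor size below is arbitrary:

import time
import warnings
import torch

warnings.filterwarnings ('ignore')   # torch.tensor() on a tensor source emits a UserWarning

src = torch.randn (100, 1000, 1000, device = 'cuda')   # assumed to fit in gpu memory
start = torch.cuda.Event (enable_timing = True)
end = torch.cuda.Event (enable_timing = True)

torch.cuda.synchronize()
t0 = time.time()
start.record()
dst = torch.tensor (src, device = 'cuda')   # gpu -> gpu copy, enqueued on the current stream
end.record()
t1 = time.time()                            # host-side: the call has already returned
torch.cuda.synchronize()                    # wait so that elapsed_time() is valid
print ('host-side return time (s): ', t1 - t0)
print ('device-side copy time (s): ', start.elapsed_time (end) / 1000)   # elapsed_time() is in ms

If torch.tensor() really does return asynchronously for gpu-resident sources, the host-side time should come out much smaller than the event-based copy time.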
“the cpu data-transfer time is the bottleneck compared to any pure gpu work (at least on my test system).”

That’s my problem: it did not seem to be the case on a compute node with an A100 and a fairly large model, but, just like you, I cannot reproduce it locally.
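One way to hunt for the offending synchronization on the compute node (a suggestion, not something tested in this thread) is PyTorch’s CUDA sync debug mode, which reports operations that force a host-device synchronization; whether it flags a blocking copy may depend on the PyTorch version:

import torch

torch.cuda.set_sync_debug_mode ('warn')      # 'warn' prints a warning per sync, 'error' raises

x = torch.randn (1000, 1000)                 # illustrative CPU tensor
y = torch.tensor (x, device = 'cuda')        # any implicit synchronization here should be reported

torch.cuda.set_sync_debug_mode ('default')   # restore normal behavior

Alternatively, a trace from torch.profiler (with both CPU and CUDA activities enabled) makes it visible whether the host thread stalls inside the copy.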