Consider the following code:
import time

import torch

if __name__ == '__main__':
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    x = torch.rand(32, 256, 220, 220).cuda()
    t = (x.min() - x.max()).to(torch.device("cpu"), non_blocking=True)
    print(t)
    time.sleep(2.)
    print(t)
and it will print:
tensor(0.)
tensor(-1.0000)
As the first print shows, the data has not been transferred to the host yet. My question is: is there some way to synchronize with the transfer? In particular, is there something I can do with a CUDA stream or event?
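(For reference, a blunt fix would be a device-wide synchronization, which blocks the host on all outstanding CUDA work before reading the value:

torch.cuda.synchronize()  # wait for all queued work on the current device
print(t)                  # tensor(-1.0000)

but I'd prefer something finer-grained than blocking on everything.)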
Which PyTorch version are you using? I cannot reproduce it in the latest release.
Maybe that's because my GPU is on a slow PCIe x4 link, which slows down the transfer. Maybe you can increase the size of x?
I’m using 1.12
Yes, you are right, and you would need to synchronize the current stream, e.g. via:
import time

import torch

if __name__ == '__main__':
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    stream = torch.cuda.current_stream()
    x = torch.rand(32, 256, 220, 220).cuda()
    t = (x.min() - x.max()).to(torch.device("cpu"), non_blocking=True)
    print(stream.query())    # False - work not done yet
    stream.synchronize()     # wait for the stream to finish the work
    print(t)
    time.sleep(2.)
    print(stream.query())    # True - work done
    print(t)
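As a side note, stream.query() is a non-blocking poll, so instead of blocking outright the host can keep doing useful work until the transfer lands. A small sketch, where do_cpu_work is a hypothetical placeholder for any host-side work you want to overlap:

while not stream.query():   # poll: True once all queued work has completed
    do_cpu_work()           # hypothetical host-side work to overlap
print(t)                    # safe to read: the non_blocking copy has landed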
Thank you. For the record, here is another example that uses CUDA events:
import time

import torch


class Timer:
    """Returns the elapsed host time since the previous call."""

    def __init__(self):
        self.start = time.monotonic()

    def __call__(self):
        k = time.monotonic()
        v = k - self.start
        self.start = k
        return v


if __name__ == '__main__':
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # print(torch.cuda.current_stream())
        x = torch.ones((32, 256, 220, 220), pin_memory=True)
        tim = Timer()
        c = torch.empty((2, 32, 256, 220, 220), device='cuda')
        print(tim())
        # x = x.to(torch.device('cuda'), non_blocking=True)
        print(tim())
        c[0, :, :, :, :].copy_(x, non_blocking=True)
        print(tim())
        # t = (x.min() - x.max()).to(torch.device("cpu"), non_blocking=True)
        t = c[0].min()
        print('mark0', tim())
        t = t.to('cpu', non_blocking=True)
        print('mark', tim())
        ev = torch.cuda.Event()
        ev.record()        # enqueue the event on the current stream
        # print(torch.cuda.current_stream())
        print(tim())
        print(t)           # may still be stale here
        ev.synchronize()   # block until all work recorded before the event is done
        print(tim())
        print(t)           # transfer complete - correct value
You will observe that only the last operation, ev.synchronize(), takes a substantial amount of time; all the other calls return almost instantly, because they merely enqueue work on the stream and only the event synchronization actually blocks the host.
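For completeness, events can also measure the GPU-side duration of the enqueued work, rather than the host-side timestamps Timer reports. A minimal sketch, assuming timing-enabled events bracketing the same copy-and-reduce pattern as above:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.ones((32, 256, 220, 220), pin_memory=True)
c = torch.empty((32, 256, 220, 220), device='cuda')

start.record()                    # enqueued on the current stream
c.copy_(x, non_blocking=True)     # async host-to-device copy
t = c.min().to('cpu', non_blocking=True)
end.record()

end.synchronize()                 # block until everything up to `end` is done
print(start.elapsed_time(end))    # GPU time in milliseconds
print(t)                          # value is valid after the synchronize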