Measuring GPU tensor operation speed

Hi,

I would like to illustrate the speed of tensor operations on GPU for a course.

The following piece of code:

import time
import torch

x = torch.cuda.FloatTensor(10000, 500).normal_()
w = torch.cuda.FloatTensor(200, 500).normal_()

a = time.time()
y = x.mm(w.t())
b = time.time()
print('batch GPU {:.02e}s'.format(b - a))

a = time.time()
y = x.mm(w.t())
b = time.time()
print('batch GPU {:.02e}s'.format(b - a))

prints

batch GPU 1.06e-01s
batch GPU 3.43e-04s

so I presume that some “lazy” operations are delayed until the first mm, and that the first timing includes some memory allocation or copying.

Is that the case? If so, is there a proper way to force all of this to happen up front, some sort of cuda.flush()?

Cheers,


torch.cuda.synchronize(), I believe

Tried that; it does not seem to help.


Yes, GPU operations are executed asynchronously with respect to the CPU, so you need to insert proper synchronization barriers for your benchmarks to be correct. Also, if you’re using Python 3, I’d recommend using time.perf_counter() instead of time.time(). Here’s a corrected script:

x = torch.cuda.FloatTensor(10000, 500).normal_()
w = torch.cuda.FloatTensor(200, 500).normal_()

# ensure that context initialization and normal_() operations
# finish before you start measuring time
torch.cuda.synchronize()

a = time.perf_counter()
y = x.mm(w.t())
torch.cuda.synchronize() # wait for mm to finish
b = time.perf_counter()
print('batch GPU {:.02e}s'.format(b - a))

a = time.perf_counter()
y = x.mm(w.t())
torch.cuda.synchronize() # wait for mm to finish
b = time.perf_counter()
print('batch GPU {:.02e}s'.format(b - a))

That said, it still gives me some weird results. Even with proper synchronization, running this timing block in a loop gives me:

batch GPU 1.64e-01s
batch GPU 1.25e-03s
batch GPU 7.01e-04s
batch GPU 6.96e-04s
batch GPU 6.94e-04s

@ngimel any ideas what might be causing it?


I believe cuBLAS handles are allocated lazily now, which means that the first operation requiring cuBLAS has the overhead of creating a cuBLAS handle, and that includes some internal allocations. So there’s no way to avoid it other than calling some function that requires cuBLAS before the timing loop.
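
For example, a minimal sketch along these lines (reusing x, w, torch and time from the snippet above) moves the handle creation out of the timed region:

# warm-up call: the first mm creates the cuBLAS handle and does its internal allocations
_ = x.mm(w.t())
torch.cuda.synchronize()

a = time.perf_counter()
y = x.mm(w.t())
torch.cuda.synchronize()  # wait for mm to finish
b = time.perf_counter()
print('batch GPU {:.02e}s'.format(b - a))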

That would explain the initial slowdown, but I’m wondering why it affects two iterations :confused:

The overhead on the second iteration is garbage collection/reference counting not keeping up, plus the caching allocator allocating a second y tensor. If you add y = None in the timing loop, you’ll be able to reuse y’s allocation, and only the first iteration will show additional overhead.
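
Something like this sketch of the timing loop (again reusing x and w from above) shows the idea; dropping the reference before the next mm lets the caching allocator hand back the same block:

for i in range(5):
    y = None  # release the previous result so its memory can be reused
    torch.cuda.synchronize()
    a = time.perf_counter()
    y = x.mm(w.t())
    torch.cuda.synchronize()  # wait for mm to finish
    b = time.perf_counter()
    print('batch GPU {:.02e}s'.format(b - a))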


You should also warm up the GPU clocks for at least a couple of seconds, average over more iterations, and subtract the fixed per-measurement overhead, which you can estimate by increasing the number of iterations.
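
As a rough sketch of what that could look like (the two-second warm-up and the iteration counts are arbitrary choices, and x and w are reused from the earlier snippets):

# run some work for a couple of seconds so the GPU clocks ramp up
t0 = time.perf_counter()
while time.perf_counter() - t0 < 2.0:
    x.mm(w.t())
torch.cuda.synchronize()

def timed(n_iters):
    # total wall-clock time for n_iters matrix multiplications
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        x.mm(w.t())
    torch.cuda.synchronize()
    return time.perf_counter() - start

# modelling timed(n) as overhead + n * per_mm, two measurements estimate both terms
t1, t2 = timed(100), timed(200)
per_mm = (t2 - t1) / 100
print('per mm {:.02e}s, fixed overhead {:.02e}s'.format(per_mm, t1 - 100 * per_mm))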

So timing in PyTorch seems to be very weird. I tried the code in this thread to verify some numbers reported in various papers, such as the ShuffleNet paper. Here’s the code:

import time

import torch
import torch.backends.cudnn as cudnn
from torchvision.models import alexnet  # assuming the torchvision AlexNet

cudnn.benchmark = True

model = alexnet().cuda().eval()

# Shufflenet model from https://github.com/jaxony/ShuffleNet/blob/master/model.py
# model = ShuffleNet().cuda().eval()

x = torch.rand(1, 3, 224, 224).cuda()

torch.cuda.synchronize()
# torch.cuda.synchronize()

sum_time = 0
total_runs = 100
warm_up = 50

for i in range(total_runs):
    torch.cuda.synchronize()
    a = time.perf_counter()
    # y = None
    y = model(x)
    torch.cuda.synchronize()
    b = time.perf_counter()

    if i >= warm_up:  # skip the first warm_up iterations while the GPU warms up
        sum_time += (b - a) * 1000

print("Time (ms): {:.04}".format(sum_time / (total_runs - warm_up)))

For AlexNet, I get ~1.327 ms, but the ShuffleNet model gives me ~5.403 ms, both of which are very different from the numbers reported in the ShuffleNet paper.

This raises the question: is wall-clock time even a reliable way to measure speed in PyTorch?

@ngimel @apaszke @Andrei_Pokrovsky