PyTorch CPU overhead of creating conv2d layers

Hi All,

I’m comparing two networks: a single large convolution and a bottleneck block consisting of 3 (example A) or 2 (example B) convolutions. Profiling their feed-forward runtime in Python (with appropriate torch.cuda.synchronize() calls, via python -m bottleneck.py and nvprof), the runtimes are not even close to what the FLOP counts predict. In fact, the smaller convolutions are slower than the large one. There seems to be a CPU overhead in torch.conv2d that is not CPU-GPU communication, so each additional layer adds overhead.

Example A: flop count speedup = 10.

input size: (64, 160, 120)

baseline_layer = torch.nn.Conv2d(64, 64, kernel_size=(3, 3), bias=False)
distilled_layer = torch.nn.Sequential(
    torch.nn.Conv2d(64, 14, kernel_size=(1, 1), bias=False),
    torch.nn.Conv2d(14, 15, kernel_size=(3, 3), bias=False),
    torch.nn.Conv2d(15, 64, kernel_size=(1, 1), bias=False))

Example B: flop count speedup = 85.3.

input size: (256, 160, 120)

baseline_layer2 = torch.nn.Conv2d(256, 128, kernel_size=(1, 1), bias=False)
distilled_layer2 = torch.nn.Sequential(
    torch.nn.Conv2d(256, 1, kernel_size=(1, 1), bias=False),
    torch.nn.Conv2d(1, 128, kernel_size=(1, 1), bias=False))
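
For reference, the timing loop is essentially the sketch below (simplified; the numbers in the table that follows also come from nvprof and cProfile runs of the same kind of loop):

import time
import torch

device = torch.device("cuda")
x = torch.randn(1, 64, 160, 120, device=device)    # batch size 1, Example A input
layer = distilled_layer.to(device).eval()          # or baseline_layer, etc.

times = []
with torch.no_grad():
    for _ in range(10000):
        start = time.perf_counter()
        layer(x)
        torch.cuda.synchronize()                   # wait for the GPU before stopping the clock
        times.append(time.perf_counter() - start)

print("mean forward time: %.3f ms" % (1e3 * sum(times) / len(times)))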

Timing

Times are in ms (approximate, averaged over 10000 feed forwards):

Example         GPU time (nvprof)   torch.conv2d time (cProfile)   Total elapsed time
A - baseline    0.12                0.08                           0.2
A - distilled   0.15                0.16                           0.3
B - baseline    0.27                0.05                           0.34
B - distilled   0.18                0.12                           0.29

The torch.conv2d overhead becomes more pronounced in a larger network like ResNet containing many such bottleneck blocks.

I understand that the GPU time won’t reflect the op count when the image / convolution are not large enough, since the kernels are highly optimized. But is it possible to remove the torch.conv2d overhead associated with adding more layers? Would that require writing a custom C++ or CUDA extension?

This is because the algorithms of the CPU and GPU implementations are different. I think the CPU implementation is based on the cross-correlation algorithm you often find in textbooks (i.e., the discrete version of it). The GPU implementation, on the other hand, can compute convolutions via fast Fourier transforms, and I think CUDA/cuDNN also uses some version of Winograd’s method.

E.g., a good overview can be found in Lavin, A., & Gray, S. (2016). Fast algorithms for convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 4013-4021). https://arxiv.org/abs/1509.09308
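
To give a flavour of what Winograd’s minimal filtering buys, here is the 1-D F(2,3) case from that paper written out as a quick numerical check (just an illustration; this is not an API that PyTorch or cuDNN exposes):

import numpy as np

d = np.random.randn(4)   # 4 consecutive input samples
g = np.random.randn(3)   # 3-tap filter

# Direct computation of the 2 outputs: 6 multiplications
y_direct = np.array([
    d[0] * g[0] + d[1] * g[1] + d[2] * g[2],
    d[1] * g[0] + d[2] * g[1] + d[3] * g[2],
])

# Winograd F(2,3): the same 2 outputs with only 4 multiplications
m1 = (d[0] - d[2]) * g[0]
m2 = (d[1] + d[2]) * (g[0] + g[1] + g[2]) / 2
m3 = (d[2] - d[1]) * (g[0] - g[1] + g[2]) / 2
m4 = (d[1] - d[3]) * g[2]
y_winograd = np.array([m1 + m2 + m3, m2 - m3 - m4])

print(np.allclose(y_direct, y_winograd))   # True

Nesting this in 2-D gives F(2x2, 3x3), which replaces 36 multiplications per 2x2 output tile with 16 (about 2.25x fewer).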

Hi Sebastian,

Thank you so much for your reply. To clarify, I’m not comparing CPU and GPU times; I’m only working with the GPU implementation. But there seems to be overhead in the PyTorch code that calls into the GPU implementation, and that is what I’m trying to understand/remove. Does this make sense?

Oh I see, that may be related to the fact that CUDA uses different algorithms for convolution. The algorithm is chosen automatically at runtime, and sometimes the choice is not ideal. (Also, CUDA does some internal optimization, so for the smaller convolutions it may not settle on a single convolution algorithm and instead keeps switching back and forth, re-estimating and re-optimizing.)

Have you set the deterministic flag for the benchmarking? I.e.,

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

Actually, since my input size is fixed, based on the thread “What does torch.backends.cudnn.benchmark do?” I am using

torch.backends.cudnn.benchmark = True

I tried both, and benchmark=True gives very slightly faster times than

torch.backends.cudnn.deterministic = True
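
As I understand it, with benchmark=True the first forward pass for a given input shape is where cuDNN runs its algorithm search, so it’s worth warming up before timing. Roughly (using the Example B layer from above as a stand-in):

torch.backends.cudnn.benchmark = True

model = distilled_layer2.cuda().eval()
x = torch.randn(1, 256, 160, 120, device="cuda")

with torch.no_grad():
    for _ in range(10):            # warm-up: the first call per input shape pays the algorithm search
        model(x)
    torch.cuda.synchronize()
    # ... timed iterations start here ...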

There’s a lot going on here. First, you’re not going to see an 85x (or 10x) speed-up at these sizes. GPUs have a lot more compute (FLOPs) than memory bandwidth. These bottleneck convolutions trade FLOPs for memory bandwidth (each convolution has to read its entire input and write its entire output). In the baseline cases the compute dominates the memory access, but that’s much less true in the bottleneck layers. (This part isn’t a matter of CPU overhead.)
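
To put rough numbers on that for Example B (back-of-the-envelope, float32 activations only, ignoring weights and any cuDNN workspace): the distilled block does ~85x fewer multiply-adds but moves essentially the same number of bytes, so it ends up limited by memory bandwidth rather than FLOPs.

H, W = 160, 120   # spatial size from Example B

def conv1x1_cost(c_in, c_out):
    """MACs and activation bytes moved (float32) for a 1x1 convolution at HxW."""
    macs = c_in * c_out * H * W
    bytes_moved = 4 * (c_in + c_out) * H * W   # read the input once, write the output once
    return macs, bytes_moved

# Baseline: Conv2d(256, 128, 1x1)
base_macs, base_bytes = conv1x1_cost(256, 128)

# Distilled: Conv2d(256, 1, 1x1) followed by Conv2d(1, 128, 1x1)
macs1, bytes1 = conv1x1_cost(256, 1)
macs2, bytes2 = conv1x1_cost(1, 128)
dist_macs, dist_bytes = macs1 + macs2, bytes1 + bytes2

print("MACs:  baseline %.0fM, distilled %.1fM -> %.1fx fewer"
      % (base_macs / 1e6, dist_macs / 1e6, base_macs / dist_macs))
print("Bytes: baseline %.1fMB, distilled %.1fMB -> ratio %.2f"
      % (base_bytes / 1e6, dist_bytes / 1e6, base_bytes / dist_bytes))

This prints roughly 629M vs 7.4M MACs (85.3x fewer), but about 29.5MB vs 29.6MB of activation traffic, i.e. the distilled block moves slightly more data than the baseline.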

Second, GPU applications, including PyTorch, rely on hiding CPU overhead by executing GPU operations asynchronously. The CPU and GPU work overlap, so as long as the CPU overhead takes less time than the GPU operations (and synchronizations are infrequent enough), the program is limited by the GPU. It looks like you’re synchronizing after every layer, which makes the overhead additive instead of overlapped. The fix is to move the synchronization outside the loop in your benchmarking code.
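
Something along these lines (a sketch) keeps the kernel launches asynchronous; only the final synchronize is needed for correct timing:

import time
import torch

model = distilled_layer2.cuda().eval()          # or any of the other layers above
x = torch.randn(1, 256, 160, 120, device="cuda")
n_iters = 10000

with torch.no_grad():
    for _ in range(100):                        # warm-up
        model(x)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iters):
        model(x)                                # no per-iteration synchronize: launches stay queued
    torch.cuda.synchronize()                    # wait once, at the end, for all outstanding GPU work
    elapsed = time.perf_counter() - start

print("per-forward time: %.1f us" % (1e6 * elapsed / n_iters))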

Third, profilers (nvprof and the autograd profiler) themselves add CPU overhead. They can be really useful for determining the execution time of GPU kernels, but they may be too pessimistic when it comes to estimating CPU overhead.

Fourth, it looks like you are running with batch size 1. If you want a performance improvement, batch your inputs. This helps in two ways: GPU kernels are often more efficient at larger batch sizes, and the per-call overhead is amortized across multiple inputs.
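
For example (illustrative batch size; only applicable if your use case allows forming batches):

# One batched forward issues the same kernel launches as a single image,
# so the CPU launch overhead is paid once for all 16 inputs.
x_batch = torch.randn(16, 256, 160, 120, device="cuda")
with torch.no_grad():
    y = distilled_layer2.cuda()(x_batch)        # y.shape == (16, 128, 160, 120)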

I ran your code with my script.py. I ran it both with and without nvprof on a P100. The results I got are below. “Total Time” is the time printed by the script. “GPU kernel time” is the sum of the average execution times of the kernels reported by nvprof.

It looks like the CPU overhead matters in the “A - Distilled” case, but not in the others. It might not be a problem if the rest of your network is not CPU-bound (thanks to latency hiding and asynchronous execution). What else can you do about it? Well, that depends on your use case (inference only? can you batch?).

Example         Total time   GPU kernel time
A - Baseline    124 us       122 us
A - Distilled   123 us        71 us (***)
B - Baseline    214 us       207 us
B - Distilled    99 us        96 us

Thank you so much for your great response! This is very useful.

  • In the application we are restricted to batch_size = 1. Feed-forward only, so as in your script, we turn off torch grads.
  • Synchronization was actually performed after the 3 layers, using a Sequential as in your code.
  • It definitely makes sense that for small inputs and kernel sizes the overheads prevent seeing the flop count speedup in runtime.

My only remaining question is: why is memory access dominant in the bottleneck case? True, each convolution layer has to read its input and write its output, but the inputs and outputs of the bottleneck layers are smaller than the baseline’s. Isn’t the number of memory accesses proportional to the number of mult-adds, since we need to read each input value every time it is used in the convolution? If so, shouldn’t the overall memory access always be proportional to the compute, and shouldn’t we therefore see a speed-up that depends only on the convolution compute time?

Thank you so much again.