Significant difference in execution time when convolution is run as nn.Conv2d and as nn.Sequential

Hi all,
I’m seeing a huge difference in execution time between a convolution run as nn.Conv2d and as nn.Sequential() on aarch64 (Nvidia Jetson AGX).
I suspect this shows up because there are no optimized kernels for the specific convolution layers I’m trying. The issue is only reproducible on an aarch64 processor. Even so, that alone doesn’t explain such a large difference in execution times.

To reproduce this issue,

  1. Run a Depthwise Convolution (groups = in_channels) and Pointwise Convolution (kernel size =1) individually as nn.Conv2d() and measure their runtime.
  2. Run the same set of convolutions as a single nn.Sequential and measure its runtime.

import torch
import time

x = torch.rand([1, 78, 56, 56])

#### nn.Conv2d
# Depthwise (groups == in_channels) and pointwise (1x1) convolutions, timed separately.
dconv = torch.nn.Conv2d(78, 78, 3, stride=2, padding=1, groups=78)
pconv = torch.nn.Conv2d(78, 78, 1)

delay = []
for i in range(30):
    start = time.perf_counter()
    y = dconv(x)
    end = time.perf_counter()
    delay.append((end - start) * 1000)
print(sum(delay) / len(delay))  # mean depthwise time in ms

delay = []
for i in range(30):
    start = time.perf_counter()
    z = pconv(y)
    end = time.perf_counter()
    delay.append((end - start) * 1000)
print(sum(delay) / len(delay))  # mean pointwise time in ms

#### nn.Sequential
# The same two layer types, run back to back as a single module.
model = torch.nn.Sequential(
    torch.nn.Conv2d(78, 78, 3, stride=2, padding=1, groups=78, device='cpu'),
    torch.nn.Conv2d(78, 78, 1, device='cpu'),
)

delay = []
for i in range(30):
    start = time.perf_counter()
    z = model(x)
    end = time.perf_counter()
    delay.append((end - start) * 1000)
print(sum(delay) / len(delay))  # mean nn.Sequential time in ms

Output

5.974839165961991
4.580187029205263
141.68961447236748

That is a roughly 13x difference in runtime (141.7 ms versus 5.97 + 4.58 ≈ 10.6 ms) between the convolutions run as individual nn.Conv2d modules and as nn.Sequential. I would expect time(dconv) + time(pconv) to be close to time(model) in the example above.
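To make the comparison less sensitive to first-call overhead, the measurement could be repeated with warm-up iterations and a median instead of a mean. The following is only a sketch of that idea, not the script that produced the numbers above: the helper name time_ms, the warm-up count, and the use of torch.no_grad() are all assumptions added here. The nn.Sequential is built from the same dconv and pconv objects so that both paths use identical weights.

```python
import statistics
import time

import torch


def time_ms(fn, warmup=10, iters=30):
    """Return the median wall-clock time of fn() in milliseconds (hypothetical helper)."""
    for _ in range(warmup):  # warm-up calls are not timed
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)
    return statistics.median(samples)


x = torch.rand(1, 78, 56, 56)
dconv = torch.nn.Conv2d(78, 78, 3, stride=2, padding=1, groups=78)
pconv = torch.nn.Conv2d(78, 78, 1)
# Reuse the same layer objects so both paths share weights.
model = torch.nn.Sequential(dconv, pconv)

# Assumption: inference only, so autograd bookkeeping is switched off.
with torch.no_grad():
    y = dconv(x)
    t_d = time_ms(lambda: dconv(x))
    t_p = time_ms(lambda: pconv(y))
    t_seq = time_ms(lambda: model(x))
    same = torch.allclose(model(x), pconv(dconv(x)))

print(f"dconv {t_d:.3f} ms, pconv {t_p:.3f} ms, sequential {t_seq:.3f} ms")
print("outputs match:", same)
```

If the gap persists with warm-up and shared weights, the difference is unlikely to be a measurement artifact of the original loop.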

Any idea why this is happening? Thanks in advance for your replies.
I have also opened an issue about this here

Your output looks quite strange and I get:

20.529749244451523
4.179799013460676
17.342762431750696

on a Xavier.

Your output looks closer to what I’d expect. I’m also running on a Xavier. I tried both PyTorch 1.9.0 and 1.8.0 and see the same results with each (the nn.Sequential() time is always significantly higher).

This is the output I get for consecutive runs of the script above.

:~$ python3 test.py
11.383225899993477
3.725462199993975
176.9323086999994
:~$ python3 test.py
5.979387333337627
3.1348664666647132
77.57876950000575
:~$ python3 test.py
5.478009166661953
4.854251033331518
105.16439119999556
:~$ python3 test.py
5.193072766667228
2.5191012333304266
135.01443140000143

Any ideas on how to narrow down the cause of the issue?
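As a quick way to see how much the slowdown varies between runs, the ratio time(model) / (time(dconv) + time(pconv)) can be computed from the four runs listed above. This is just arithmetic on the posted numbers; the variable names are made up for illustration.

```python
import statistics

# (dconv ms, pconv ms, nn.Sequential ms) copied from the four runs above.
runs = [
    (11.383225899993477, 3.725462199993975, 176.9323086999994),
    (5.979387333337627, 3.1348664666647132, 77.57876950000575),
    (5.478009166661953, 4.854251033331518, 105.16439119999556),
    (5.193072766667228, 2.5191012333304266, 135.01443140000143),
]

# Slowdown of nn.Sequential relative to the sum of the two separate layers.
ratios = [seq / (d + p) for d, p, seq in runs]

for (d, p, seq), ratio in zip(runs, ratios):
    print(f"separate {d + p:7.2f} ms   sequential {seq:7.2f} ms   ratio {ratio:5.1f}x")

print(f"median ratio: {statistics.median(ratios):.1f}x")
```

The ratio stays well above 1 in every run but is far from constant, which suggests the nn.Sequential timing itself is noisy from run to run.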

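I also profiled the code (the two Conv2d modules and the nn.Sequential) with the PyTorch profiler. The snippet below shows the nn.Sequential case; the three tables are, in order, the depthwise conv, the pointwise conv, and the nn.Sequential.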

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as prof:
    model(x)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::conv2d         0.06%      17.000us       100.00%      26.736ms      26.736ms             1
             aten::convolution         0.03%       9.000us        99.94%      26.719ms      26.719ms             1
            aten::_convolution         4.74%       1.268ms        99.90%      26.710ms      26.710ms             1
    aten::_convolution_nogroup         4.35%       1.164ms        83.49%      22.323ms     286.192us            78
             aten::thnn_conv2d         1.56%     417.000us        78.83%      21.077ms     270.218us            78
     aten::thnn_conv2d_forward        46.99%      12.564ms        77.27%      20.660ms     264.872us            78
                  aten::narrow         3.66%     978.000us        10.84%       2.899ms      12.336us           235
                  aten::select         7.99%       2.136ms         9.32%       2.493ms      10.654us           234
               aten::unsqueeze         6.30%       1.685ms         7.59%       2.029ms       8.671us           234
                   aten::slice         5.56%       1.487ms         7.19%       1.921ms       8.174us           235
                  aten::addmm_         5.55%       1.485ms         5.55%       1.485ms      19.038us            78
              aten::as_strided         4.25%       1.135ms         4.25%       1.135ms       1.615us           703
                 aten::reshape         1.73%     463.000us         2.31%     617.000us       7.910us            78
                 aten::resize_         1.78%     477.000us         1.78%     477.000us       3.038us           157
                   aten::copy_         1.76%     471.000us         1.76%     471.000us       6.038us            78
                   aten::empty         1.41%     378.000us         1.41%     378.000us       1.609us           235
                    aten::view         1.16%     311.000us         1.16%     311.000us       1.994us           156
                     aten::cat         0.33%      88.000us         0.86%     230.000us     230.000us             1
                    aten::_cat         0.45%     121.000us         0.53%     142.000us     142.000us             1
       aten::_nnpack_available         0.31%      82.000us         0.31%      82.000us       1.051us            78
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 26.736ms

------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::conv2d         0.34%      11.000us       100.00%       3.214ms       3.214ms             1
             aten::convolution         0.34%      11.000us        99.66%       3.203ms       3.203ms             1
            aten::_convolution         0.50%      16.000us        99.32%       3.192ms       3.192ms             1
    aten::_convolution_nogroup         0.96%      31.000us        98.82%       3.176ms       3.176ms             1
             aten::thnn_conv2d         0.34%      11.000us        97.48%       3.133ms       3.133ms             1
     aten::thnn_conv2d_forward        19.88%     639.000us        97.14%       3.122ms       3.122ms             1
                   aten::copy_        55.54%       1.785ms        55.54%       1.785ms       1.785ms             1
                  aten::addmm_        18.20%     585.000us        18.20%     585.000us     585.000us             1
                  aten::select         0.65%      21.000us         1.28%      41.000us      13.667us             3
              aten::as_strided         0.84%      27.000us         0.84%      27.000us       4.500us             6
                   aten::empty         0.59%      19.000us         0.59%      19.000us       6.333us             3
               aten::unsqueeze         0.31%      10.000us         0.53%      17.000us       5.667us             3
                    aten::view         0.44%      14.000us         0.44%      14.000us       2.800us             5
       aten::_nnpack_available         0.37%      12.000us         0.37%      12.000us      12.000us             1
                 aten::resize_         0.34%      11.000us         0.34%      11.000us      11.000us             1
                 aten::reshape         0.28%       9.000us         0.34%      11.000us      11.000us             1
                  aten::detach         0.06%       2.000us         0.06%       2.000us       2.000us             1
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 3.214ms

------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::conv2d         0.01%      15.000us       100.00%     175.729ms      87.865ms             2
             aten::convolution         0.01%      16.000us        99.99%     175.714ms      87.857ms             2
            aten::_convolution         0.77%       1.348ms        99.98%     175.698ms      87.849ms             2
    aten::_convolution_nogroup         0.66%       1.156ms        97.28%     170.944ms       2.164ms            79
             aten::thnn_conv2d         0.28%     488.000us        96.58%     169.724ms       2.148ms            79
     aten::thnn_conv2d_forward        91.71%     161.165ms        96.31%     169.236ms       2.142ms            79
                  aten::narrow         0.60%       1.056ms         1.80%       3.169ms      13.485us           235
                   aten::copy_         1.21%       2.133ms         1.21%       2.133ms      27.000us            79
                   aten::slice         1.03%       1.810ms         1.20%       2.113ms       8.991us           235
                  aten::select         0.79%       1.394ms         0.96%       1.689ms       7.127us           237
               aten::unsqueeze         0.78%       1.373ms         0.95%       1.662ms       7.013us           237
              aten::as_strided         0.50%     887.000us         0.50%     887.000us       1.251us           709
                  aten::addmm_         0.50%     874.000us         0.50%     874.000us      11.063us            79
                 aten::reshape         0.29%     516.000us         0.42%     734.000us       9.291us            79
                 aten::resize_         0.29%     504.000us         0.29%     504.000us       3.190us           158
                    aten::view         0.22%     386.000us         0.22%     386.000us       2.398us           161
                   aten::empty         0.18%     310.000us         0.18%     310.000us       1.303us           238
                     aten::cat         0.06%     106.000us         0.14%     246.000us     246.000us             1
                    aten::_cat         0.07%     123.000us         0.08%     140.000us     140.000us             1
       aten::_nnpack_available         0.04%      64.000us         0.04%      64.000us       0.810us            79
                  aten::detach         0.00%       5.000us         0.00%       5.000us       5.000us             1
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 175.729ms
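In the tables above, the average aten::thnn_conv2d_forward time per call goes from about 0.26 ms in the depthwise-only profile to about 2.1 ms in the nn.Sequential profile, but the combined table does not show which layer the extra time belongs to. One possible way to attribute it is to wrap each layer of the nn.Sequential in a labelled region with record_function. This is a sketch; the label format layer_<index>_<class name> is an arbitrary choice made here.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

x = torch.rand(1, 78, 56, 56)
model = torch.nn.Sequential(
    torch.nn.Conv2d(78, 78, 3, stride=2, padding=1, groups=78),
    torch.nn.Conv2d(78, 78, 1),
)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    out = x
    # Run the layers one by one so each gets its own labelled profiler region.
    for idx, layer in enumerate(model):
        with record_function(f"layer_{idx}_{type(layer).__name__}"):
            out = layer(out)

table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```

The layer_0_Conv2d and layer_1_Conv2d rows then give the total CPU time of the depthwise and pointwise stages as they run inside the same forward pass.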

Hi!
Could you let me know which version of PyTorch you tried?
Is there any other way to determine whether this is a bug in PyTorch or a performance problem on my Xavier board?

Thanks in advance.

I’ve used a recent source build to test it. I’m unsure which particular commit it was, but it should be ~1 month old.

@ptrblck

Interestingly, I’m not observing this issue when I install the pre-built aarch64 wheels from here: https://download.pytorch.org/whl/torch_stable.html.

Seems like I’m missing something while building from source.
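One way to look for what differs between the source build and the pre-built wheel would be to print the build configuration and threading settings in both environments and diff the output. The sketch below only uses introspection calls that PyTorch exposes; the function name build_summary and the choice of fields are assumptions made for illustration, not a known fix.

```python
import torch


def build_summary():
    """Collect build and threading details worth diffing between two installs."""
    return {
        "version": str(torch.__version__),
        "num_threads": torch.get_num_threads(),
        "num_interop_threads": torch.get_num_interop_threads(),
        "mkldnn_available": torch.backends.mkldnn.is_available(),
        "openmp_available": torch.backends.openmp.is_available(),
        # Compiler flags, BLAS backend and enabled libraries for this build.
        "config": torch.__config__.show(),
        # Threading backend details (OpenMP / native thread pool settings).
        "parallel_info": torch.__config__.parallel_info(),
    }


summary = build_summary()
for key, value in summary.items():
    print(f"--- {key} ---")
    print(value)
```

Running this once with the source build and once with the wheel, then comparing the two outputs, should show differences in enabled backends or thread settings if there are any.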