Single Kernel for Depthwise Convolution on ARM64

Hello All,

I ran the following code on an NVIDIA Jetson AGX (8-core ARM64 CPU) with PyTorch built from source.
As the profiling trace shows, aten::conv2d calls aten::_convolution_nogroup 384 times, once per channel (the layer has groups=384).
On x86 CPUs, MKL-DNN is used instead, and it executes a grouped convolution efficiently with a single kernel call. I tried building PyTorch from source with XNNPACK, but it didn't improve performance. Is there another library that provides a single-kernel depthwise convolution on ARM64? Any help is much appreciated :slight_smile:

import torch

x = torch.rand([1, 384, 56, 56])
# Depthwise 3x3 convolution: groups == in_channels == out_channels
conv = torch.nn.Conv2d(384, 384, 3, groups=384)
# Warm-up calls so one-time setup is not included in the profile
conv(x)
conv(x)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as prof:
    conv(x)
prof.export_chrome_trace("execTrace.json")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  aten::conv2d         0.01%       9.000us       100.00%     121.898ms     121.898ms             1  
             aten::convolution         0.01%      10.000us        99.99%     121.889ms     121.889ms             1  
            aten::_convolution         5.23%       6.370ms        99.98%     121.879ms     121.879ms             1  
    aten::_convolution_nogroup         2.58%       3.140ms        79.55%      96.964ms     252.510us           384  
             aten::thnn_conv2d         1.36%       1.656ms        76.65%      93.430ms     243.307us           384  
     aten::thnn_conv2d_forward        47.47%      57.860ms        75.29%      91.774ms     238.995us           384  
                  aten::narrow         3.34%       4.071ms        11.87%      14.465ms      12.546us          1153  
                  aten::addmm_         8.60%      10.482ms         8.60%      10.482ms      27.297us           384  
                   aten::slice         7.28%       8.874ms         8.54%      10.413ms       9.031us          1153  
                  aten::select         3.62%       4.415ms         4.46%       5.431ms       4.714us          1152  
               aten::unsqueeze         3.26%       3.977ms         4.20%       5.123ms       4.447us          1152  
                   aten::copy_         3.54%       4.315ms         3.54%       4.315ms      11.237us           384  
                     aten::cat         0.47%     568.000us         3.36%       4.093ms       4.093ms             1  
              aten::as_strided         3.02%       3.682ms         3.02%       3.682ms       1.065us          3457  
                    aten::_cat         2.86%       3.484ms         2.89%       3.525ms       3.525ms             1  
                 aten::reshape         1.94%       2.367ms         2.68%       3.272ms       8.521us           384  
                 aten::resize_         2.18%       2.660ms         2.18%       2.660ms       3.459us           769  
                    aten::view         1.48%       1.799ms         1.48%       1.799ms       2.342us           768  
                   aten::empty         1.45%       1.765ms         1.45%       1.765ms       1.531us          1153  
       aten::_nnpack_available         0.32%     394.000us         0.32%     394.000us       1.026us           384  
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 121.898ms
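The trace is consistent with the grouped convolution being lowered to one ungrouped convolution per group: 384 `thnn_conv2d` calls, many `aten::narrow` slices, and a single `aten::cat`. As an illustration of that decomposition only (not PyTorch's internal code; `per_group_conv2d` is a hypothetical name), the sketch below checks that a depthwise conv equals 384 single-channel convolutions on `narrow()`-ed slices joined with `cat`.

```python
import torch
import torch.nn.functional as F


def per_group_conv2d(x, weight, bias, groups):
    """Hypothetical helper: grouped conv2d as `groups` separate ungrouped convs.

    Mirrors the narrow -> conv -> cat pattern visible in the profile;
    stride 1, no padding, no dilation (the nn.Conv2d defaults).
    """
    in_per_group = x.size(1) // groups
    out_per_group = weight.size(0) // groups
    outputs = []
    for g in range(groups):
        xg = x.narrow(1, g * in_per_group, in_per_group)
        wg = weight.narrow(0, g * out_per_group, out_per_group)
        bg = None if bias is None else bias.narrow(0, g * out_per_group, out_per_group)
        outputs.append(F.conv2d(xg, wg, bg))  # one ungrouped conv call per group
    return torch.cat(outputs, dim=1)


if __name__ == "__main__":
    torch.manual_seed(0)
    C = 384
    x = torch.rand(1, C, 56, 56)
    conv = torch.nn.Conv2d(C, C, 3, groups=C)
    with torch.no_grad():
        ref = conv(x)
        split = per_group_conv2d(x, conv.weight, conv.bias, groups=C)
    print(torch.allclose(ref, split, atol=1e-5))
```

Each iteration does very little arithmetic (one 3x3 filter over one 56x56 plane), so per-call overhead is likely a large share of the 384 calls; a single depthwise kernel avoids that overhead entirely.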