Modifying GPU operators

Hello, I would like to know if it is possible to modify the operations implemented by PyTorch on the GPU that are used in pre-trained models like ResNet50 (resnet50 — Torchvision 0.20 documentation).

For example, I would like to know where the source code for operations like convolution or batch normalization is located when PyTorch performs inference on the GPU, the hierarchy of calls, and whether it is possible to modify these operators by adding extra parameters or operations.




You could profile your code to see which kernels are used internally.
E.g. here is a simple code snippet showing the usage of nn.Conv2d:

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

device = "cuda"

# Record both the CPU-side operator calls and the CUDA kernels they launch
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
# Sort the table by total GPU time ("cuda_time_total")
sort_by_keyword = device + "_time_total"

# Single convolution layer: 3 input channels, 3 output channels, 3x3 kernel, stride 1, padding 1
model = nn.Conv2d(3, 3, 3, 1, 1, device=device)
x = torch.randn(1, 3, 224, 224, device=device)

with profile(activities=activities, record_shapes=True) as prof:
    out = model(x)
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                            aten::conv2d         0.20%     139.812us        99.99%      71.080ms      71.080ms       0.000us         0.00%      18.369us      18.369us             1  
#                                       aten::convolution         0.06%      45.846us        99.79%      70.940ms      70.940ms       0.000us         0.00%      18.369us      18.369us             1  
#                                      aten::_convolution         0.64%     456.376us        99.73%      70.894ms      70.894ms       0.000us         0.00%      18.369us      18.369us             1  
#                                 aten::cudnn_convolution        57.82%      41.100ms        83.88%      59.626ms      59.626ms      14.945us        81.36%      14.945us      14.945us             1  
# void cudnn::cnn::conv2d_grouped_direct_kernel<false,...         0.00%       0.000us         0.00%       0.000us       0.000us      12.800us        69.68%      12.800us      12.800us             1  
#                                              aten::add_         0.59%     416.511us        15.18%      10.789ms      10.789ms       3.424us        18.64%       3.424us       3.424us             1  
# void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       3.424us        18.64%       3.424us       3.424us             1  
#                                         Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       2.145us        11.68%       2.145us       2.145us             1  
#                                      cudaGetDeviceCount         0.00%       1.994us         0.00%       1.994us       1.994us       0.000us         0.00%       0.000us       0.000us             1  
#                                    cudaDriverGetVersion         0.00%       0.170us         0.00%       0.170us       0.170us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 71.086ms
# Self CUDA time total: 18.369us

# Disable cuDNN so that PyTorch falls back to its native convolution implementation
torch.backends.cudnn.enabled = False
with profile(activities=activities, record_shapes=True) as prof:
    out = model(x)
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                            aten::conv2d         0.01%       5.560us        99.99%      76.870ms      76.870ms             1  
#                                       aten::convolution         0.04%      31.268us        99.98%      76.865ms      76.865ms             1  
#                                      aten::_convolution        10.40%       7.996ms        99.94%      76.833ms      76.833ms             1  
#                                 aten::_nnpack_available        25.28%      19.438ms        25.28%      19.438ms      19.438ms             1  
#                                       aten::thnn_conv2d         0.02%      15.148us        64.26%      49.399ms      49.399ms             1  
#                              aten::_slow_conv2d_forward        50.91%      39.139ms        64.24%      49.384ms      49.384ms             1  
#                                             aten::empty         0.05%      38.101us         0.28%     216.997us     108.499us             2  
#                                              aten::view         0.01%       7.445us         0.01%       7.445us       3.722us             2  
#                                           aten::resize_         0.02%      14.547us         0.02%      14.547us      14.547us             1  
#                                   cudaStreamIsCapturing         0.01%       6.061us         0.01%       6.061us       6.061us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 76.878ms

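In both profiles the call chain starts with aten::conv2d, which calls aten::convolution, which calls aten::_convolution; only the backend call below that differs.
The first profile shows that PyTorch dispatches to aten::cudnn_convolution, which launches cudnn::cnn::conv2d_grouped_direct_kernel. This kernel is part of cuDNN, which is closed source, so you cannot modify it.
After disabling cuDNN, the second profile shows that PyTorch dispatches to aten::_slow_conv2d_forward, which is implemented here.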
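If you only need the operator names rather than the full timing table, the same profiler output can be filtered programmatically. This is a minimal sketch, not part of the original snippet: the helper name dispatched_op_names is made up for illustration, and it falls back to the CPU when no GPU is available, in which case the backend ops at the bottom of the chain will differ from the ones shown above.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity


def dispatched_op_names(module, example_input):
    """Run one forward pass under the profiler and return the recorded event names.

    Hypothetical helper for illustration only.
    """
    activities = [ProfilerActivity.CPU]
    if example_input.is_cuda:
        # Also record the CUDA kernels when the input lives on the GPU
        activities.append(ProfilerActivity.CUDA)
    with torch.no_grad(), profile(activities=activities) as prof:
        module(example_input)
    return [evt.key for evt in prof.key_averages()]


# Assumption: use the GPU if present, otherwise the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
conv = nn.Conv2d(3, 3, 3, 1, 1, device=device)
x = torch.randn(1, 3, 224, 224, device=device)

names = dispatched_op_names(conv, x)
# Keep only the convolution-related entries of the call chain
conv_ops = [name for name in names if "conv" in name.lower()]
print(conv_ops)
```

The same helper could be pointed at a full model such as ResNet50 to list every operator it dispatches to, including the batch normalization ops.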

You can also implement custom ops if that fits your use case better.
