Hello, I would like to know if it is possible to modify the operations implemented by PyTorch on the GPU that are used in pre-trained models like ResNet50 (resnet50 — Torchvision 0.20 documentation).
For example, I would like to know where the source code for operations like convolution or batch normalization is located when PyTorch performs inference on the GPU, the hierarchy of calls, and whether it is possible to modify these operators by adding extra parameters or operations.
You can profile your code to see which operators and kernels are called internally.
For example, here is a simple code snippet that profiles a single nn.Conv2d call:
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity
device = "cuda"
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
sort_by_keyword = device + "_time_total"
model = nn.Conv2d(3, 3, 3, 1, 1, device=device)
x = torch.randn(1, 3, 224, 224, device=device)
with profile(activities=activities, record_shapes=True) as prof:
    out = model(x)
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# aten::conv2d 0.20% 139.812us 99.99% 71.080ms 71.080ms 0.000us 0.00% 18.369us 18.369us 1
# aten::convolution 0.06% 45.846us 99.79% 70.940ms 70.940ms 0.000us 0.00% 18.369us 18.369us 1
# aten::_convolution 0.64% 456.376us 99.73% 70.894ms 70.894ms 0.000us 0.00% 18.369us 18.369us 1
# aten::cudnn_convolution 57.82% 41.100ms 83.88% 59.626ms 59.626ms 14.945us 81.36% 14.945us 14.945us 1
# void cudnn::cnn::conv2d_grouped_direct_kernel<false,... 0.00% 0.000us 0.00% 0.000us 0.000us 12.800us 69.68% 12.800us 12.800us 1
# aten::add_ 0.59% 416.511us 15.18% 10.789ms 10.789ms 3.424us 18.64% 3.424us 3.424us 1
# void at::native::elementwise_kernel<128, 2, at::nati... 0.00% 0.000us 0.00% 0.000us 0.000us 3.424us 18.64% 3.424us 3.424us 1
# Memset (Device) 0.00% 0.000us 0.00% 0.000us 0.000us 2.145us 11.68% 2.145us 2.145us 1
# cudaGetDeviceCount 0.00% 1.994us 0.00% 1.994us 1.994us 0.000us 0.00% 0.000us 0.000us 1
# cudaDriverGetVersion 0.00% 0.170us 0.00% 0.170us 0.170us 0.000us 0.00% 0.000us 0.000us 1
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Self CPU time total: 71.086ms
# Self CUDA time total: 18.369us
# Disable cuDNN globally so the convolution falls back to PyTorch's native implementation
torch.backends.cudnn.enabled = False
with profile(activities=activities, record_shapes=True) as prof:
    out = model(x)
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
# aten::conv2d 0.01% 5.560us 99.99% 76.870ms 76.870ms 1
# aten::convolution 0.04% 31.268us 99.98% 76.865ms 76.865ms 1
# aten::_convolution 10.40% 7.996ms 99.94% 76.833ms 76.833ms 1
# aten::_nnpack_available 25.28% 19.438ms 25.28% 19.438ms 19.438ms 1
# aten::thnn_conv2d 0.02% 15.148us 64.26% 49.399ms 49.399ms 1
# aten::_slow_conv2d_forward 50.91% 39.139ms 64.24% 49.384ms 49.384ms 1
# aten::empty 0.05% 38.101us 0.28% 216.997us 108.499us 2
# aten::view 0.01% 7.445us 0.01% 7.445us 3.722us 2
# aten::resize_ 0.02% 14.547us 0.02% 14.547us 14.547us 1
# cudaStreamIsCapturing 0.01% 6.061us 0.01% 6.061us 6.061us 1
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
# Self CPU time total: 76.878ms
The first profile shows that PyTorch dispatches to cudnn::cnn::conv2d_grouped_direct_kernel (via aten::cudnn_convolution), which is a closed-source cuDNN implementation, so you cannot modify it.
After disabling cuDNN, the second profile shows that PyTorch dispatches to aten::_slow_conv2d_forward instead, which is an open-source native kernel implemented here.
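Note that torch.backends.cudnn.enabled = False affects the whole process. If you only want to check the fallback path for a single call, a minimal sketch could look like the following. It assumes the torch.backends.cudnn.flags context manager, which restores the previous settings on exit (and applies its own defaults for the other cuDNN flags inside the block). profiled_op_names is a hypothetical helper name used only for illustration.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity


def profiled_op_names(model, x, use_cudnn):
    """Run one forward pass and return the names of all profiled events.

    Hypothetical helper: cuDNN is enabled or disabled only inside this call.
    """
    activities = [ProfilerActivity.CPU]
    if x.is_cuda:
        # Only request CUDA activity when the input actually lives on the GPU
        activities.append(ProfilerActivity.CUDA)
    # The context manager restores the previous cuDNN flags when it exits
    with torch.backends.cudnn.flags(enabled=use_cudnn):
        with torch.no_grad(), profile(activities=activities) as prof:
            model(x)
    return {evt.key for evt in prof.key_averages()}


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Conv2d(3, 3, 3, 1, 1, device=device)
x = torch.randn(1, 3, 224, 224, device=device)

with_cudnn = profiled_op_names(model, x, use_cudnn=True)
without_cudnn = profiled_op_names(model, x, use_cudnn=False)
# Operators that only show up on the fallback path
print(sorted(without_cudnn - with_cudnn))
```

On a GPU, the difference between the two sets should point at the operators that replace the cuDNN path; on a CPU-only machine the flag makes no difference and the difference is expected to be empty or nearly so.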
You can also implement custom ops if that fits your use case better.
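If a fully registered custom op is more than you need, a lighter option is to wrap the existing operator at the Python level. The sketch below is only an illustration, not a registered custom op: ScaledConv2d, its scale parameter, and replace_convs are hypothetical names, and it assumes the convolutions use the default zero padding mode. It subclasses nn.Conv2d, adds one extra parameter and one extra operation, and swaps every nn.Conv2d in a model for the new module while reusing the pre-trained weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledConv2d(nn.Conv2d):
    """nn.Conv2d with one extra (hypothetical) parameter: a scalar output scale."""

    def __init__(self, *args, scale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale = scale

    def forward(self, x):
        # Regular convolution call, so the usual backend dispatch
        # (cuDNN or the native kernel) still applies on the GPU.
        out = F.conv2d(x, self.weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
        # The extra operation added on top of the stock operator
        return out * self.scale


def replace_convs(module, scale):
    """Recursively swap each nn.Conv2d for a ScaledConv2d that reuses its weights."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d) and not isinstance(child, ScaledConv2d):
            new = ScaledConv2d(child.in_channels, child.out_channels,
                               child.kernel_size, stride=child.stride,
                               padding=child.padding, dilation=child.dilation,
                               groups=child.groups,
                               bias=child.bias is not None, scale=scale)
            # Share the existing (e.g. pre-trained) parameters
            new.weight = child.weight
            new.bias = child.bias
            setattr(module, name, new)
        else:
            replace_convs(child, scale)
    return module


# Usage on a small stand-in model; a torchvision ResNet could be passed the same way
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU())
model = replace_convs(model, scale=2.0)
out = model(torch.randn(1, 3, 8, 8))
print(type(model[0]).__name__, tuple(out.shape))
```

This keeps the underlying kernels untouched and only changes what happens around them. Changing the kernel itself would require editing the native source and rebuilding PyTorch, or writing a custom op with your own kernel.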