Given three tensors A, B, and C of shape (N, N) on a CUDA device, I have a simple function that computes
A * B + C
To understand torch.compile, I ran and profiled this code with and without torch.compile.
Without torch.compile, the code runs as expected, i.e. there are two kernel launches: one for the element-wise multiplication (temp = A * B) and another for the element-wise addition (temp + C).
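For reference, the eager-mode table below came from profiling essentially this sketch (the CPU fallback, the smaller n, and returning the result are my additions so the snippet runs anywhere; the actual run used n = 16000 on CUDA and sorted by cuda_time_total):

```python
import torch
import torch.profiler as profiler

# Sketch of the eager-mode run. Assumptions: device fallback and a smaller n
# so this runs without a GPU; the actual run used n = 16000 on 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n = 1024

A = torch.randn((n, n), device=device)
B = torch.randn((n, n), device=device)
C = torch.randn((n, n), device=device)

def test():
    # Eager mode: one element-wise mul kernel, then one element-wise add kernel.
    # Returning the result here so the sketch is easy to check; the full
    # script at the end assigns it to D instead.
    return A * B + C

with profiler.profile() as prof:
    with profiler.record_function("test eager"):
        test()

print(prof.key_averages().table(sort_by="cpu_time_total"))
```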
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
test compile 0.00% 0.000us 0.00% 0.000us 0.000us 46.816ms 12199.90% 46.816ms 46.816ms 1
test compile 1.30% 1.350ms 99.98% 103.550ms 103.550ms 0.000us 0.00% 383.738us 383.738us 1
aten::add 1.67% 1.734ms 42.17% 43.679ms 43.679ms 194.109us 50.58% 194.109us 194.109us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 194.109us 50.58% 194.109us 194.109us 1
aten::mul 16.62% 17.213ms 56.50% 58.522ms 58.522ms 189.629us 49.42% 189.629us 189.629us 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 189.629us 49.42% 189.629us 189.629us 1
cudaLaunchKernel 80.38% 83.253ms 80.38% 83.253ms 41.626ms 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceSynchronize 0.02% 19.208us 0.02% 19.208us 19.208us 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 103.570ms
Self CUDA time total: 383.738us
With torch.compile, when I compile and run for the first time, I get these statistics:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
compile_fx.<locals>.fw_compiler_base (dynamo_timed) 0.00% 0.000us 0.00% 0.000us 0.000us 133.232ms 164244.17% 133.232ms 133.232ms 1
test compile 0.87% 20.587ms 100.00% 2.357s 2.357s 0.000us 0.00% 81.118us 81.118us 1
_compile.compile_inner (dynamo_timed) 5.44% 128.167ms 99.11% 2.336s 2.336s 0.000us 0.00% 81.118us 81.118us 1
OutputGraph.call_user_compiler (dynamo_timed) 80.97% 1.908s 93.03% 2.192s 2.192s 0.000us 0.00% 81.118us 81.118us 1
create_aot_dispatcher_function (dynamo_timed) 0.84% 19.879ms 12.06% 284.118ms 284.118ms 0.000us 0.00% 81.118us 81.118us 1
aten::copy_ 0.66% 15.654ms 0.71% 16.787ms 258.256us 81.118us 100.00% 81.118us 1.248us 65
compile_fx.<locals>.fw_compiler_base (dynamo_timed) 6.84% 161.163ms 9.72% 229.105ms 229.105ms 0.000us 0.00% 81.118us 81.118us 1
aten::clone 0.05% 1.138ms 0.56% 13.224ms 322.525us 0.000us 0.00% 64.255us 1.567us 41
aten::lift_fresh 0.79% 18.681ms 0.93% 21.862ms 295.427us 0.000us 0.00% 64.255us 0.868us 74
Memcpy DtoD (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 64.255us 79.21% 64.255us 2.677us 24
aten::to 0.24% 5.714ms 0.67% 15.848ms 417.043us 0.000us 0.00% 16.863us 0.444us 38
aten::_to_copy 0.17% 4.092ms 0.43% 10.133ms 422.215us 0.000us 0.00% 16.863us 0.703us 24
Memcpy HtoD (Pageable -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 16.863us 20.79% 16.863us 0.703us 24
TorchDynamo Cache Lookup 0.00% 3.039us 0.00% 3.039us 0.760us 0.000us 0.00% 0.000us 0.000us 4
cudaStreamIsCapturing 0.00% 25.294us 0.00% 25.294us 3.613us 0.000us 0.00% 0.000us 0.000us 7
aten::empty_strided 0.33% 7.718ms 0.33% 7.718ms 26.432us 0.000us 0.00% 0.000us 0.000us 292
aten::detach 0.02% 582.685us 0.31% 7.316ms 25.316us 0.000us 0.00% 0.000us 0.000us 289
detach 0.14% 3.265ms 0.26% 6.163ms 106.254us 0.000us 0.00% 0.000us 0.000us 58
aten::empty 1.07% 25.186ms 1.17% 27.655ms 119.720us 0.000us 0.00% 0.000us 0.000us 231
cudaMalloc 0.01% 253.339us 0.01% 253.339us 253.339us 0.000us 0.00% 0.000us 0.000us 1
aten::mul 0.25% 5.918ms 0.72% 16.917ms 3.383ms 0.000us 0.00% 0.000us 0.000us 5
prims::mul 0.14% 3.220ms 0.22% 5.136ms 5.136ms 0.000us 0.00% 0.000us 0.000us 1
aten::empty_permuted 0.08% 1.797ms 0.08% 1.995ms 665.058us 0.000us 0.00% 0.000us 0.000us 3
aten::as_strided 0.06% 1.371ms 0.06% 1.459ms 182.417us 0.000us 0.00% 0.000us 0.000us 8
aten::add 0.14% 3.378ms 0.18% 4.324ms 864.771us 0.000us 0.00% 0.000us 0.000us 5
prims::add 0.01% 170.167us 0.01% 212.438us 212.438us 0.000us 0.00% 0.000us 0.000us 1
aten::detach_ 0.03% 725.800us 0.03% 760.008us 20.541us 0.000us 0.00% 0.000us 0.000us 37
detach_ 0.00% 34.208us 0.00% 34.208us 0.925us 0.000us 0.00% 0.000us 0.000us 37
aten::alias 0.01% 331.017us 0.02% 570.771us 570.771us 0.000us 0.00% 0.000us 0.000us 1
aten::view 0.01% 150.991us 0.01% 239.754us 239.754us 0.000us 0.00% 0.000us 0.000us 1
prims::view_of 0.00% 74.287us 0.00% 88.763us 88.763us 0.000us 0.00% 0.000us 0.000us 1
aten::set_ 0.09% 2.227ms 0.09% 2.227ms 222.736us 0.000us 0.00% 0.000us 0.000us 10
cudaMemcpyAsync 0.04% 957.462us 0.04% 957.462us 19.947us 0.000us 0.00% 0.000us 0.000us 48
cudaStreamSynchronize 0.01% 174.836us 0.01% 174.836us 7.285us 0.000us 0.00% 0.000us 0.000us 24
aten::randperm 0.20% 4.798ms 0.47% 11.176ms 2.235ms 0.000us 0.00% 0.000us 0.000us 5
aten::slice 0.25% 5.887ms 0.40% 9.515ms 2.379ms 0.000us 0.00% 0.000us 0.000us 4
aten::index_add 0.01% 297.864us 0.07% 1.638ms 1.638ms 0.000us 0.00% 0.000us 0.000us 1
aten::index_put 0.05% 1.204ms 0.10% 2.468ms 822.563us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_like 0.00% 99.503us 0.01% 145.478us 48.493us 0.000us 0.00% 0.000us 0.000us 3
aten::index 0.07% 1.663ms 0.22% 5.269ms 2.634ms 0.000us 0.00% 0.000us 0.000us 2
aten::new_empty 0.04% 844.291us 0.04% 941.363us 470.681us 0.000us 0.00% 0.000us 0.000us 2
compile_fx_inner (dynamo_timed) 0.04% 836.516us 0.04% 836.516us 836.516us 0.000us 0.00% 0.000us 0.000us 1
Torch-Compiled Region 0.01% 262.690us 0.01% 263.364us 263.364us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceSynchronize 0.00% 14.106us 0.00% 14.106us 14.106us 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.357s
Self CUDA time total: 81.118us
Now, if I create new matrices A, B, and C and run the profiling again, I get:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
test compile 0.60% 1.547ms 99.99% 259.521ms 259.521ms 1
TorchDynamo Cache Lookup 0.01% 23.106us 0.01% 23.106us 23.106us 1
_compile.compile_inner (dynamo_timed) 66.16% 171.720ms 99.36% 257.887ms 257.887ms 1
cudaStreamIsCapturing 0.01% 25.576us 0.01% 25.576us 4.263us 6
aten::empty_strided 4.11% 10.656ms 4.11% 10.656ms 313.399us 34
aten::detach 0.08% 214.091us 0.54% 1.391ms 27.824us 50
detach 0.43% 1.110ms 0.45% 1.177ms 42.039us 28
aten::mul 10.17% 26.409ms 11.04% 28.659ms 7.165ms 4
aten::empty 0.83% 2.145ms 0.83% 2.145ms 102.150us 21
aten::add 7.09% 18.409ms 7.98% 20.708ms 5.177ms 4
OutputGraph.call_user_compiler (dynamo_timed) 1.42% 3.679ms 16.59% 43.057ms 43.057ms 1
create_aot_dispatcher_function (dynamo_timed) 5.77% 14.983ms 14.67% 38.071ms 38.071ms 1
aten::clone 0.08% 220.023us 0.23% 601.647us 35.391us 17
aten::copy_ 0.08% 195.953us 0.08% 195.953us 11.527us 17
aten::to 0.03% 86.083us 0.03% 86.083us 6.622us 13
aten::lift_fresh 2.26% 5.855ms 2.57% 6.676ms 256.776us 26
aten::detach_ 0.08% 209.552us 0.09% 221.823us 17.063us 13
detach_ 0.00% 12.271us 0.00% 12.271us 0.944us 13
compile_fx.<locals>.fw_compiler_base (dynamo_timed) 0.44% 1.150ms 0.75% 1.950ms 1.950ms 1
compile_fx_inner (dynamo_timed) 0.31% 799.775us 0.31% 799.775us 799.775us 1
aten::empty_like 0.00% 7.427us 0.01% 18.140us 9.070us 2
Torch-Compiled Region 0.02% 64.157us 0.02% 64.157us 64.157us 1
cudaDeviceSynchronize 0.01% 29.074us 0.01% 29.074us 29.074us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 259.550ms
In both cases, one thing I observe is that aten::add, aten::mul, prims::add, and prims::mul all appear to run on the CPU (they show no Self CUDA time), even though the matrices are on the GPU.
Am I misunderstanding something in these profiling results?
Here is the code to reproduce:
import torch
import torch.profiler as profiler

n = 16000
A = torch.randn((n, n), device='cuda')
B = torch.randn((n, n), device='cuda')
C = torch.randn((n, n), device='cuda')
D = torch.zeros((n, n), device='cuda')

def test():
    D = A * B + C

test_c = torch.compile(test, fullgraph=True, mode="max-autotune")

with torch.profiler.profile() as prof:
    with torch.profiler.record_function("test compile"):
        test_c()

prof.export_chrome_trace("trace_compile.json")
print(prof.key_averages().table(sort_by="cuda_time_total"))