Hello.
I’m developing a simple Python application for personal use that detects the fundamental frequencies of two harmonic sounds simultaneously. To achieve this, I’ve created and trained a small PyTorch neural network. I’m going to use it on a machine without a GPU, so the code below is executed on the CPU.
Network(
  (conv1): DepthwiseConvBlock(
    (depth_conv): Conv1d(1, 16, kernel_size=(11,), stride=(1,), padding=same)
    (point_conv): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
    (nonlin): GELU(approximate='none')
    (norm): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
    (pool): MaxPool2d(kernel_size=[1, 2], stride=[1, 2], padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): DepthwiseConvBlock(
    (depth_conv): Conv1d(16, 16, kernel_size=(7,), stride=(1,), padding=same, groups=16)
    (point_conv): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
    (nonlin): GELU(approximate='none')
    (norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
    (pool): MaxPool2d(kernel_size=[1, 2], stride=[1, 2], padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): DepthwiseConvBlock(
    (depth_conv): Conv1d(16, 8, kernel_size=(5,), stride=(1,), padding=same, groups=8)
    (point_conv): Conv1d(8, 8, kernel_size=(1,), stride=(1,))
    (nonlin): GELU(approximate='none')
    (norm): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
    (pool): MaxPool2d(kernel_size=[1, 2], stride=[1, 2], padding=0, dilation=1, ceil_mode=False)
  )
  (flat): Flatten(start_dim=1, end_dim=-1)
  (linear): Linear(in_features=96, out_features=8, bias=True)
)
The input to this network is a spectrum computed by a CQT transform function that operates on NumPy audio arrays. The network, in turn, produces probabilities for the detected notes. However, there’s an issue: when I run the network in the same loop as the CQT transform function, inference becomes dramatically slower.
Just the CQT transform function:
def profile_cqt_to_torch():
    cqts = []
    for a in audio:
        cqt = cqt_rough(a)
        cqts.append(torch.from_numpy(cqt).unsqueeze(0).to(torch.float))
    return cqts

t = time()
cqts = profile_cqt_to_torch()
print(time() - t)
>>> 0.35244059562683105
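The time() pattern above can be made a little more robust. The following is only a minimal sketch: time.perf_counter is a monotonic, high-resolution clock intended for measuring intervals, and taking the best of several repeats reduces the influence of one-off warm-up costs. The names time_call and dummy_workload are hypothetical and not part of the original code; dummy_workload stands in for a function such as profile_cqt_to_torch.

```python
from time import perf_counter


def time_call(fn, repeats=3):
    """Call fn() `repeats` times and return the fastest duration in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = perf_counter()
        fn()
        best = min(best, perf_counter() - start)
    return best


def dummy_workload():
    # Hypothetical stand-in; replace with e.g. profile_cqt_to_torch.
    return sum(i * i for i in range(10_000))


elapsed = time_call(dummy_workload)
print(f"{elapsed:.6f} s")
```

Reporting the minimum rather than the mean is a design choice: it approximates the cost without interference from other processes, which is useful when comparing the separate and combined loops.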
Just the network with precalculated CQTs:
network.eval()

def profile_net():
    preds = []
    for cqt in cqts:
        with torch.no_grad():
            pred = network(cqt)
        preds.append(pred)
    return preds

t = time()
preds = profile_net()
print(time() - t)
>>> [W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
>>> 0.5710155963897705
Combined:
network.eval()

def profile_cqt_net():
    preds = []
    for a in audio:
        cqt = cqt_rough(a)
        with torch.no_grad():
            pred = network(torch.from_numpy(cqt).unsqueeze(0).to(torch.float))
        preds.append(pred)
    return preds

t = time()
preds = profile_cqt_net()
print(time() - t)
>>> 15.845890522003174
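To put the three measurements side by side, here is a small arithmetic sketch. The timings are copied from the runs above. The count of 500 inputs is an assumption taken from the call counts in the profiler tables further down, and it only affects the per-input figures, not the ratio.

```python
# Timings copied from the three runs above (seconds).
cqt_only = 0.35244059562683105
net_only = 0.5710155963897705
combined = 15.845890522003174

# Assumption: every run processes the same 500 inputs, the call count
# reported in the profiler tables further down.
n_inputs = 500

# Cost of running the two stages in separate loops.
separate = cqt_only + net_only
# How much slower the combined loop is than the two separate loops.
slowdown = combined / separate

print(f"separate: {separate / n_inputs * 1e3:.2f} ms per input")
print(f"combined: {combined / n_inputs * 1e3:.2f} ms per input")
print(f"slowdown: {slowdown:.1f}x")
```

Under that assumption the two separate loops cost roughly 1.8 ms per input, while the combined loop costs roughly 32 ms per input, a slowdown of about 17 times.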
By using the new torch.compile function with mode="reduce-overhead", I’ve managed to decrease the execution time to approximately 10 seconds. Nevertheless, this still seems excessively long.
cProfile and torch.profiler give similar results. I check the dtype and size of the input tensor on every iteration.
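For reference, a cProfile report ordered by cumulative time can be collected along the following lines. This is a minimal sketch using only the standard library, not necessarily how the results in this post were produced. The workload function is a hypothetical stand-in for profile_net or profile_cqt_net.

```python
import cProfile
import io
import pstats


def workload():
    # Hypothetical stand-in for profile_net() or profile_cqt_net().
    return sum(i * i for i in range(50_000))


profiler = cProfile.Profile()
profiler.enable()
result = workload()
profiler.disable()

# Sort by cumulative time and keep the ten most expensive entries.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("cumulative").print_stats(10)
report = buffer.getvalue()
print(report)
```

Writing the report into a StringIO buffer rather than straight to stdout makes it easy to save the two reports and compare them afterwards.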
Profiling with precalculated CQTs:
Code:
from torch.profiler import profile, record_function, ProfilerActivity

def profile_net():
    preds = []
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        for cqt in cqts:
            assert cqt.shape == torch.Size([1, 96])
            assert cqt.dtype == torch.float32
            with record_function("model_inference"):
                with torch.no_grad():
                    pred = network(cqt)
            preds.append(pred)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    return preds
cProfile results:
92544 function calls (72544 primitive calls) in 0.850 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
2 0.000 0.000 0.850 0.425 /home/user/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3406(run_code)
2 0.000 0.000 0.850 0.425 {built-in method builtins.exec}
1 0.000 0.000 0.849 0.849 <ipython-input-24-33789e8bbc84>:1(<module>)
1 0.007 0.007 0.849 0.849 <ipython-input-23-24b737ab1132>:6(profile_net)
10500/500 0.023 0.000 0.828 0.002 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1514(_wrapped_call_impl)
10500/500 0.044 0.000 0.827 0.002 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520(_call_impl)
500 0.020 0.000 0.823 0.002 /mnt/3660917E7771EA5C/Programming/PitchDetect/TheModel.py:71(forward)
1500 0.049 0.000 0.720 0.000 /mnt/3660917E7771EA5C/Programming/PitchDetect/TheModel.py:49(forward)
3000 0.016 0.000 0.378 0.000 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:309(forward)
3000 0.008 0.000 0.356 0.000 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:301(_conv_forward)
3000 0.348 0.000 0.348 0.000 {built-in method torch.conv1d}
1500 0.003 0.000 0.087 0.000 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/activation.py:681(forward)
1500 0.084 0.000 0.084 0.000 {built-in method torch._C._nn.gelu}
1500 0.010 0.000 0.068 0.000 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/normalization.py:195(forward)
1500 0.006 0.000 0.060 0.000 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/pooling.py:165(forward)
1500 0.008 0.000 0.055 0.000 /home/user/.local/lib/python3.10/site-packages/torch/nn/functional.py:2528(layer_norm)
1500 0.005 0.000 0.055 0.000 /home/user/.local/lib/python3.10/site-packages/torch/_jit_internal.py:478(fn)
1500 0.004 0.000 0.050 0.000 /home/user/.local/lib/python3.10/site-packages/torch/nn/functional.py:769(_max_pool2d)
1500 0.045 0.000 0.045 0.000 {built-in method torch.max_pool2d}
1500 0.044 0.000 0.044 0.000 {built-in method torch.layer_norm}
torch.profiler results:
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
model_inference 33.48% 264.789ms 99.68% 788.410ms 1.577ms 500
aten::conv1d 2.02% 16.011ms 43.88% 347.084ms 115.695us 3000
aten::convolution 1.94% 15.362ms 41.53% 328.489ms 109.496us 3000
aten::_convolution 5.33% 42.175ms 39.61% 313.289ms 104.430us 3000
aten::_convolution_mode 1.38% 10.923ms 29.09% 230.100ms 153.400us 1500
aten::mkldnn_convolution 18.30% 144.708ms 19.21% 151.955ms 151.955us 1000
aten::thnn_conv2d 0.71% 5.634ms 10.87% 85.951ms 42.975us 2000
aten::_slow_conv2d_forward 6.81% 53.884ms 10.21% 80.788ms 40.394us 2000
aten::gelu 6.63% 52.461ms 6.63% 52.461ms 34.974us 1500
aten::layer_norm 0.64% 5.027ms 5.72% 45.237ms 30.158us 1500
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 790.968ms
With both functions combined:
Code:
def profile_cqt_net():
    preds = []
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        for a in audio:
            cqt = cqt_rough(a)
            cqt = torch.from_numpy(cqt).unsqueeze(0).to(torch.float)
            assert cqt.shape == torch.Size([1, 96])
            assert cqt.dtype == torch.float32
            with torch.no_grad():
                with record_function("model_inference"):
                    pred = network(cqt)
            preds.append(pred)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
    return preds
cProfile results:
109039 function calls (107539 primitive calls) in 11.783 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
2 0.000 0.000 11.783 5.891 /home/user/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3406(run_code)
2 0.000 0.000 11.783 5.891 {built-in method builtins.exec}
1 0.001 0.001 11.782 11.782 <ipython-input-16-b93e40bc29de>:1(<module>)
1 0.027 0.027 11.782 11.782 <ipython-input-13-9f753a9ee13c>:7(profile_cqt_net)
1000/500 0.005 0.000 10.409 0.021 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1514(_wrapped_call_impl)
1000/500 0.015 0.000 10.407 0.021 /home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520(_call_impl)
1000/500 0.020 0.000 10.397 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:307(_fn)
500 0.003 0.000 10.333 0.021 /mnt/3660917E7771EA5C/Programming/PitchDetect/TheModel.py:71(forward)
500 0.002 0.000 10.311 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:15(inner)
500 0.003 0.000 10.308 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:3901(forward)
500 0.002 0.000 10.305 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1481(g)
500 0.010 0.000 10.303 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2519(runtime_wrapper)
500 0.007 0.000 10.292 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1498(call_func_with_args)
500 0.002 0.000 10.282 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1583(rng_functionalization_wrapper)
500 0.002 0.000 10.280 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py:373(__call__)
500 0.002 0.000 10.277 0.021 /home/user/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py:385(_run_from_cache)
500 0.169 0.000 10.275 0.021 /tmp/torchinductor_user/fy/cfyr3q7akk2fssiuzxfhjdricxsnhqemn66ccqvr6xnodgbjmeai.py:297(call)
3000 9.927 0.003 9.927 0.003 {built-in method torch.convolution}
500 0.010 0.000 1.275 0.003 /mnt/3660917E7771EA5C/Programming/PitchDetect/cqt_rough.py:77(__call__)
1500 1.154 0.001 1.242 0.001 /mnt/3660917E7771EA5C/Programming/PitchDetect/cqt_rough.py:74(apply_filters)
1500 0.006 0.000 0.088 0.000 {method 'mean' of 'numpy.ndarray' objects}
1500 0.033 0.000 0.082 0.000 /home/user/.local/lib/python3.10/site-packages/numpy/core/_methods.py:163(_mean)
2500 0.010 0.000 0.069 0.000 /home/user/.local/lib/python3.10/site-packages/torch/_ops.py:687(__call__)
2500 0.059 0.000 0.059 0.000 {built-in method torch._ops.inductor._reinterpret_tensor}
3500 0.054 0.000 0.054 0.000 {built-in method torch.empty_strided}
500 0.036 0.000 0.036 0.000 {built-in method torch.addmm}
1500 0.024 0.000 0.024 0.000 {method 'reduce' of 'numpy.ufunc' objects}
torch.profiler results:
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
model_inference 2.63% 643.239ms 99.60% 24.327s 48.653ms 500
aten::conv1d 0.38% 92.808ms 40.14% 9.805s 3.268ms 3000
aten::convolution 0.18% 42.919ms 39.91% 9.748s 3.249ms 3000
aten::_convolution 0.41% 100.469ms 39.73% 9.705s 3.235ms 3000
aten::_convolution_mode 0.11% 27.077ms 38.94% 9.512s 6.341ms 1500
aten::mkldnn_convolution 29.10% 7.108s 29.22% 7.137s 7.137ms 1000
aten::gelu 28.33% 6.920s 28.33% 6.920s 4.614ms 1500
aten::layer_norm 0.07% 17.180ms 27.44% 6.701s 4.467ms 1500
aten::native_layer_norm 27.26% 6.658s 27.36% 6.684s 4.456ms 1500
aten::thnn_conv2d 0.08% 19.653ms 9.72% 2.375s 1.187ms 2000
aten::_slow_conv2d_forward 9.33% 2.280s 9.67% 2.363s 1.181ms 2000
aten::unsqueeze 0.33% 79.823ms 0.41% 100.883ms 13.451us 7500
aten::linear 0.08% 20.467ms 0.34% 82.141ms 164.282us 500
aten::max_pool2d 0.08% 18.410ms 0.29% 71.995ms 47.997us 1500
aten::to 0.04% 8.668ms 0.22% 54.208ms 108.416us 500
aten::copy_ 0.22% 54.167ms 0.22% 54.167ms 18.056us 3000
aten::max_pool2d_with_indices 0.22% 53.585ms 0.22% 53.585ms 35.723us 1500
aten::_to_copy 0.11% 25.727ms 0.19% 46.555ms 93.110us 500
aten::view 0.17% 42.428ms 0.17% 42.428ms 4.243us 10000
aten::addmm 0.11% 27.823ms 0.17% 41.023ms 82.046us 500
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 24.425s
OS: Ubuntu 22.04
PyTorch: 2.1.0
I have no clue how to fix this problem and am hoping for help.