Hi, I would like to ask a question about torch.mv() and torch.mm(). On CUDA they achieve very high performance, finishing in microseconds. However, the first time they are called, they spend a long time (on the order of seconds) in cudaFree. Could anyone give me some ideas on how to avoid this?
For example:

import torch

a = torch.empty(4, 4, device=torch.device('cuda'))
torch.mm(a, a)
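The table below was collected with roughly the following setup (a minimal sketch using torch.profiler; my exact invocation may differ slightly, and the record_function label "prof" is just what produces the prof row in the output):

import torch
from torch.profiler import profile, record_function, ProfilerActivity

a = torch.empty(4, 4, device=torch.device('cuda'))

# Profile both CPU and CUDA activity around a single matmul
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
    with record_function("prof"):
        torch.mm(a, a)

print(p.key_averages().table(sort_by="self_cpu_time_total"))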
and the PyTorch Profiler shows the following:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
prof 0.02% 338.000us 100.00% 2.164s 2.164s 0.000us 0.00% 6.000us 6.000us 1
aten::mm 0.54% 11.660ms 99.98% 2.164s 2.164s 6.000us 100.00% 6.000us 6.000us 1
cudaFree 99.24% 2.148s 99.24% 2.148s 1.074s 0.000us 0.00% 0.000us 0.000us 2
cudaMalloc 0.20% 4.293ms 0.20% 4.293ms 1.431ms 0.000us 0.00% 0.000us 0.000us 3
cudaLaunchKernel 0.00% 66.000us 0.00% 66.000us 66.000us 0.000us 0.00% 0.000us 0.000us 1
aten::zeros 0.00% 45.000us 0.00% 58.000us 58.000us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceSynchronize 0.00% 42.000us 0.00% 42.000us 42.000us 0.000us 0.00% 0.000us 0.000us 1
aten::empty 0.00% 13.000us 0.00% 13.000us 6.500us 0.000us 0.00% 0.000us 0.000us 2
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla... 0.00% 7.000us 0.00% 7.000us 7.000us 0.000us 0.00% 0.000us 0.000us 1
cudaGetSymbolAddress 0.00% 2.000us 0.00% 2.000us 2.000us 0.000us 0.00% 0.000us 0.000us 1
aten::zero_ 0.00% 1.000us 0.00% 1.000us 1.000us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceGetAttribute 0.00% 1.000us 0.00% 1.000us 0.071us 0.000us 0.00% 0.000us 0.000us 14
void gemmSN_NN_kernel<float, 256, 4, 2, 8, 4, 4, fal... 0.00% 0.000us 0.00% 0.000us 0.000us 6.000us 100.00% 6.000us 6.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.164s
Self CUDA time total: 6.000us
How can I shorten this CPU time? And why does torch.mv()/torch.mm() call cudaFree, which the documentation mentions we should avoid using?
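One workaround I am considering is to pay this one-time cost at program start-up, outside the timed region, with a throwaway warm-up call (a minimal sketch; I have not confirmed whether this removes the cudaFree time entirely, and warm_up_cuda_blas is just a hypothetical helper name):

import torch

def warm_up_cuda_blas():
    # Run one tiny matmul so the one-time CUDA/cuBLAS initialization
    # (which shows up as cudaFree in the profile above) happens here
    # rather than inside the code I actually want to time.
    x = torch.empty(4, 4, device='cuda')
    torch.mm(x, x)
    torch.cuda.synchronize()

warm_up_cuda_blas()
# Later torch.mm / torch.mv calls should then run in microseconds.

Is that the recommended approach, or is there a setting that avoids the cudaFree call altogether?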