Replacing torch.zeros internals with cudaMemset instead of a fill kernel

I’ve noticed that torch.zeros launches a fill kernel when it constructs a tensor on the GPU. Would it be possible, and preferable, to use cudaMemset instead? It seems like the more intuitive and efficient choice to me. If it is not suitable, why not?

torch.zeros first allocates device memory (via aten::empty, which goes through PyTorch's CUDA caching allocator and only calls cudaMalloc when no cached block fits) and then calls fill_kernel_cuda. As far as I understand, it should be possible to replace fill_kernel_cuda, at least for zero fills, with an implementation that calls cudaMemset.
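One constraint worth spelling out: cudaMemset writes the same byte to every position in the range, so it can only produce fill values whose encoding is a single byte repeated. Zero qualifies for the usual integer and floating-point dtypes (all bits zero), which makes it a natural fast path for torch.zeros. Most other values do not: 1.0 in float16 is 0x3C00, for example, so a general fill kernel would still be needed for fill_ with arbitrary values. The sketch below checks this with Python's struct module (format "e" is IEEE half precision); `memset_byte_for` is a hypothetical helper for illustration, not a PyTorch or CUDA function.

```python
import struct
from typing import Optional, Union


def memset_byte_for(value: Union[int, float], fmt: str) -> Optional[int]:
    """Return the byte that, repeated, encodes `value` in struct format `fmt`
    (little-endian), or None if no single repeated byte can encode it.

    A byte-wise memset (such as cudaMemset) can only produce values for which
    this returns a byte.
    """
    encoded = struct.pack("<" + fmt, value)
    first = encoded[0]
    return first if all(b == first for b in encoded) else None


def describe(value: Union[int, float], fmt: str) -> str:
    byte = memset_byte_for(value, fmt)
    return f"memset with 0x{byte:02X}" if byte is not None else "needs a fill kernel"


if __name__ == "__main__":
    # struct codes: 'e' = float16, 'f' = float32, 'd' = float64, 'i' = int32
    for fmt, dtype in [("e", "float16"), ("f", "float32"), ("d", "float64"), ("i", "int32")]:
        for value in (0, 1, -1):
            print(f"{dtype:8s} {value:>2}: {describe(value, fmt)}")
```

Note that -1 in int32 (0xFFFFFFFF) is also memset-able, while -0.0 in float16 (0x8000) is not, so in practice only zero is a useful memset case across all floating-point dtypes.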

Are you seeing any benefits from using cudaMemset instead of the TensorIterator approach we are using, i.e., did you compare both approaches?
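The other constraint is layout. A TensorIterator-based fill handles arbitrary strides, whereas a single cudaMemset call covers one contiguous byte range, so it can only zero a tensor whose elements occupy exactly that range. torch.zeros returns a freshly allocated contiguous tensor, so it would always qualify, but zero_() on a strided view such as x[:, ::2] would not, because a memset over the view's span would also clear the skipped columns. The sketch below models that decision in plain Python; `plan_zero_fill` and `is_c_contiguous` are hypothetical names, and the sketch only accepts row-major contiguous layouts (a real implementation could also accept other non-overlapping, dense layouts).

```python
from typing import Sequence, Tuple


def is_c_contiguous(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    """True if the strides (in elements) describe a dense row-major layout."""
    expected = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if size == 1:
            continue  # the stride of a size-1 dimension is never used
        if stride != expected:
            return False
        expected *= size
    return True


def plan_zero_fill(sizes: Sequence[int], strides: Sequence[int],
                   itemsize: int) -> Tuple[str, int]:
    """Hypothetical dispatch: one memset over the whole byte range when the
    layout allows it, otherwise an element-wise (strided) fill kernel.
    Returns the strategy and the number of bytes a memset would cover (0 if unused)."""
    numel = 1
    for size in sizes:
        numel *= size
    if numel == 0:
        return ("nothing", 0)
    if is_c_contiguous(sizes, strides):
        return ("memset", numel * itemsize)
    return ("elementwise", 0)


if __name__ == "__main__":
    # Like torch.zeros((1024, 1024, 1024), dtype=torch.float16): contiguous, 2 bytes/element
    print(plan_zero_fill((1024, 1024, 1024), (1024 * 1024, 1024, 1), 2))
    # Like x[:, ::2] of a contiguous 4x4 float32 tensor: every other column, strides (4, 2)
    print(plan_zero_fill((4, 2), (4, 2), 4))
```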

Sorry for the late response. I have done some basic profiling with the PyTorch Profiler, and it looks like switching the implementation to cudaMemset does improve the performance of torch.zeros.

I profiled the creation of an fp16 zeros tensor of shape (1024, 1024, 1024) on the GPU as shown below:

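import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Listing ProfilerActivity.CUDA already enables GPU tracing, so the
# deprecated use_cuda=True flag is not needed.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True,
             profile_memory=True) as prof:
    with record_function("torch_zeros"):
        # 1024**3 fp16 elements = 2 GiB on the device
        zero_tensor = torch.zeros((1024, 1024, 1024),
                                  dtype=torch.float16,
                                  device='cuda')
    # Kernel launches are asynchronous; wait for the fill to finish.
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))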

I got the following results for the current implementation:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_zeros         6.35%      19.267ms        99.19%     300.836ms     300.836ms       0.000us         0.00%       2.798ms       2.798ms           0 b           0 b       2.00 Gb           0 b             1  
                                            aten::zeros         0.04%     130.736us        92.18%     279.572ms     279.572ms       0.000us         0.00%       2.798ms       2.798ms           0 b           0 b       2.00 Gb           0 b             1  
                                            aten::zero_         0.01%      38.260us         2.17%       6.594ms       6.594ms       0.000us         0.00%       2.798ms       2.798ms           0 b           0 b           0 b           0 b             1  
                                            aten::fill_         0.02%      57.585us         2.16%       6.556ms       6.556ms       2.798ms       100.00%       2.798ms       2.798ms           0 b           0 b           0 b           0 b             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.798ms       100.00%       2.798ms       2.798ms           0 b           0 b           0 b           0 b             1  
                                            torch_zeros         0.00%       0.000us         0.00%       0.000us       0.000us       2.798ms       100.00%       2.798ms       2.798ms           0 b           0 b           0 b           0 b             1  
                                     cudaGetDeviceCount         0.00%       1.407us         0.00%       1.407us       0.703us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
                             cudaGetDeviceProperties_v2         0.66%       1.996ms         0.66%       1.996ms       1.996ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                            aten::empty         0.04%     108.655us        89.96%     272.847ms     272.847ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b       2.00 Gb       2.00 Gb             1  
                       cudaDeviceGetStreamPriorityRange        89.84%     272.472ms        89.84%     272.472ms     272.472ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 303.297ms
Self CUDA time total: 2.798ms

And the following results when I replace the fill kernel with a cudaMemset call:

------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         torch_zeros         6.61%      19.356ms        99.57%     291.549ms     291.549ms       0.000us         0.00%       1.579ms       1.579ms           0 b           0 b       2.00 Gb           0 b             1  
                         aten::zeros         0.05%     133.432us        92.29%     270.246ms     270.246ms       0.000us         0.00%       1.579ms       1.579ms           0 b           0 b       2.00 Gb           0 b             1  
                         aten::zero_         0.01%      38.235us         0.04%     102.532us     102.532us       0.000us         0.00%       1.579ms       1.579ms           0 b           0 b           0 b           0 b             1  
                         aten::fill_         0.01%      39.738us         0.02%      64.297us      64.297us       1.579ms       100.00%       1.579ms       1.579ms           0 b           0 b           0 b           0 b             1  
                     Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       1.579ms       100.00%       1.579ms       1.579ms           0 b           0 b           0 b           0 b             1  
                         torch_zeros         0.00%       0.000us         0.00%       0.000us       0.000us       1.579ms       100.00%       1.579ms       1.579ms           0 b           0 b           0 b           0 b             1  
                  cudaGetDeviceCount         0.00%       1.285us         0.00%       1.285us       0.643us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
          cudaGetDeviceProperties_v2         0.66%       1.946ms         0.66%       1.946ms       1.946ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                         aten::empty         0.04%     120.588us        92.21%     270.010ms     270.010ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b       2.00 Gb       2.00 Gb             1  
    cudaDeviceGetStreamPriorityRange        92.08%     269.620ms        92.08%     269.620ms     269.620ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 292.815ms
Self CUDA time total: 1.579ms
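To put the two GPU times on a common scale, they can be converted into an effective write bandwidth (bytes written divided by kernel time). The sketch below does that with the Self CUDA times from the two tables, assuming the "2.00 Gb" in the CUDA Mem column is 2 GiB (1024**3 fp16 elements at 2 bytes each). The GPU model is not stated, so the figures are mainly useful relative to each other.

```python
# Effective write bandwidth implied by the Self CUDA times in the two profiler tables.
NUMEL = 1024 * 1024 * 1024          # shape (1024, 1024, 1024)
BYTES_PER_ELEMENT = 2               # torch.float16
NBYTES = NUMEL * BYTES_PER_ELEMENT  # 2 GiB, the "2.00 Gb" CUDA Mem column

KERNEL_TIMES_S = {
    "fill kernel (vectorized_elementwise_kernel)": 2.798e-3,
    "cudaMemset (Memset (Device))": 1.579e-3,
}


def effective_bandwidth_gb_per_s(nbytes: int, seconds: float) -> float:
    """Bytes written per second, in decimal GB/s (1 GB = 1e9 bytes)."""
    return nbytes / seconds / 1e9


for name, seconds in KERNEL_TIMES_S.items():
    print(f"{name}: {effective_bandwidth_gb_per_s(NBYTES, seconds):.0f} GB/s")
```

With these inputs it prints about 768 GB/s for the fill kernel and 1360 GB/s for cudaMemset. Comparing both against the GPU's peak memory bandwidth would show how much headroom the fill kernel leaves.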

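The GPU time of the fill drops from 2.798 ms (vectorized_elementwise_kernel) to 1.579 ms (Memset (Device)), roughly 44% less for this 2 GiB tensor. The end-to-end saving looks small because, in both runs, the CPU total is dominated by about 270 ms in cudaDeviceGetStreamPriorityRange, which most likely reflects one-time CUDA initialization on the first call. Still, cudaMemset is faster here and, in my opinion, a more straightforward and intuitive way to zero a buffer.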
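Both profiles time a single call in a fresh process, so they include one-time setup and give only one sample of the kernel time. A steadier comparison would time many calls after a few warm-up runs. Below is a minimal sketch of such a harness; `median_seconds` is a hypothetical helper, not a PyTorch API (PyTorch's own torch.utils.benchmark.Timer, e.g. its blocked_autorange method, is a ready-made alternative). Because CUDA launches are asynchronous, the caller must pass a `sync` function such as torch.cuda.synchronize; otherwise only the launch would be timed.

```python
import statistics
import time
from typing import Callable


def median_seconds(fn: Callable[[], object],
                   sync: Callable[[], None] = lambda: None,
                   warmup: int = 3,
                   repeats: int = 20) -> float:
    """Median wall-clock seconds per call of `fn`, after `warmup` untimed calls.

    `sync` must block until all work queued by `fn` has finished
    (for CUDA work, e.g. torch.cuda.synchronize).
    """
    for _ in range(warmup):
        fn()
    sync()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        sync()  # include the asynchronous GPU work in the measurement
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        t = median_seconds(
            lambda: torch.zeros((1024, 1024, 1024), dtype=torch.float16, device="cuda"),
            sync=torch.cuda.synchronize,
        )
        print(f"torch.zeros, 2 GiB fp16: {t * 1e3:.3f} ms (median)")
    else:
        print("PyTorch with CUDA is not available; skipping the GPU measurement")
```

Running the same script against both builds (fill kernel and cudaMemset) would give two medians that can be compared directly. The wall-clock figure also includes allocation from the caching allocator and launch overhead, which should be small next to a millisecond-scale fill.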