An efficient implementation of indexing two axes

The same torch.compile approach from your previous question can also be used here. Profiling the eager version gives:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us        78.43%      40.000us       4.000us            10  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      11.000us        21.57%      11.000us       1.100us            10  
                                        cudaMemcpyAsync        31.91%      90.000us        31.91%      90.000us       4.500us       0.000us         0.00%       0.000us       0.000us            20  
                                  cudaStreamSynchronize        28.01%      79.000us        28.01%      79.000us       3.950us       0.000us         0.00%       0.000us       0.000us            20  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us            20  
                                       cudaLaunchKernel        38.30%     108.000us        38.30%     108.000us       5.400us       0.000us         0.00%       0.000us       0.000us            20  
                                  cudaDeviceSynchronize         1.77%       5.000us         1.77%       5.000us       5.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 282.000us
Self CUDA time total: 51.000us
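
With torch.compile (profiled over a single call rather than 10), the indexing runs as one Triton kernel:
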
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
      triton__0d1d2d3d4de         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us       100.00%       1.000us       1.000us             1  
           cuLaunchKernel        99.63%       2.147ms        99.63%       2.147ms       2.147ms       0.000us         0.00%       0.000us       0.000us             1  
    cudaDeviceSynchronize         0.37%       8.000us         0.37%       8.000us       8.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.155ms
Self CUDA time total: 1.000us