An efficient implementation of indexing along two axes

I want to index a tensor x along two axes simultaneously, with shifted indices. Here is a for loop that does what I want. How can I implement it efficiently in PyTorch, without the for loop?

import torch

a, b, s1, c, t, d = 2, 4, 5, 10, 6, 7
x = torch.rand(a, b, s1, c, t, d)
x_new = torch.zeros(a, b, s1, c, t, d)
for i in range(c):
    # Indices along dim 3, shifted by i: index[j] = (j + i) % c; shape (s1,)
    index = torch.roll(torch.arange(c), -i)[:s1]
    # Gather x[:, :, j, (j + i) % c, ...] for j in 0..s1-1; shape (a, b, s1, t, d)
    tmp = x[:, :, torch.arange(s1), index, ...]
    x_new[:, :, :, i, ...] = tmp

The same approach from your previous question, compiling the loop with torch.compile, also works here.
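For reference, a minimal sketch of how that could look (the function name shift_index and the device handling are mine; the loop body is your snippet unchanged, and the torch.compile wrapper is the only real addition):

import torch

def shift_index(x: torch.Tensor, s1: int, c: int) -> torch.Tensor:
    # Same loop as in the question: x_new[:, :, j, i] = x[:, :, j, (j + i) % c]
    x_new = torch.zeros_like(x)
    rows = torch.arange(s1, device=x.device)
    for i in range(c):
        index = torch.roll(torch.arange(c, device=x.device), -i)[:s1]
        x_new[:, :, :, i] = x[:, :, rows, index]
    return x_new

compiled_shift_index = torch.compile(shift_index)

device = "cuda" if torch.cuda.is_available() else "cpu"
a, b, s1, c, t, d = 2, 4, 5, 10, 6, 7
x = torch.rand(a, b, s1, c, t, d, device=device)

out = compiled_shift_index(x, s1, c)
assert torch.allclose(out, shift_index(x, s1, c))  # sanity check against the eager loop

Profiling both versions on the GPU gives the output below: the first table is the eager loop, the second the compiled function (the triton__... kernel is generated by torch.compile):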

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us        78.43%      40.000us       4.000us            10  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      11.000us        21.57%      11.000us       1.100us            10  
                                        cudaMemcpyAsync        31.91%      90.000us        31.91%      90.000us       4.500us       0.000us         0.00%       0.000us       0.000us            20  
                                  cudaStreamSynchronize        28.01%      79.000us        28.01%      79.000us       3.950us       0.000us         0.00%       0.000us       0.000us            20  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us            20  
                                       cudaLaunchKernel        38.30%     108.000us        38.30%     108.000us       5.400us       0.000us         0.00%       0.000us       0.000us            20  
                                  cudaDeviceSynchronize         1.77%       5.000us         1.77%       5.000us       5.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 282.000us
Self CUDA time total: 51.000us

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
      triton__0d1d2d3d4de         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us       100.00%       1.000us       1.000us             1  
           cuLaunchKernel        99.63%       2.147ms        99.63%       2.147ms       2.147ms       0.000us         0.00%       0.000us       0.000us             1  
    cudaDeviceSynchronize         0.37%       8.000us         0.37%       8.000us       8.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.155ms
Self CUDA time total: 1.000us