I want to index a tensor x simultaneously along two axes, with shifted indices. Here is a for loop that does what I want. How can I implement it efficiently in PyTorch, without the for loop?
```python
import torch

a, b, s1, c, t, d = 2, 4, 5, 10, 6, 7
x = torch.rand(a, b, s1, c, t, d)
x_new = torch.zeros(a, b, s1, c, t, d)
for i in range(c):
    index = torch.roll(torch.arange(c), -i)[:s1]   # shape: (s1,)
    tmp = x[:, :, range(s1), index, ...]           # shape: (a, b, s1, t, d)
    x_new[:, :, :, i, ...] = tmp
```
The same approach from your previous question, using torch.compile, can also be used here. Profiling the eager loop (first table) against the compiled version (second table):
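A minimal sketch of how the loop could be wrapped and compiled for the measurements below (the wrapper name `shift_gather` is an assumption, not from the thread):

```python
import torch

def shift_gather(x):
    # Same loop as in the question; torch.compile can fuse the c iterations
    # into far fewer kernel launches.
    a, b, s1, c, t, d = x.shape
    x_new = torch.zeros_like(x)
    for i in range(c):
        index = torch.roll(torch.arange(c), -i)[:s1]   # (s1,)
        x_new[:, :, :, i] = x[:, :, range(s1), index]  # (a, b, s1, t, d)
    return x_new

x = torch.rand(2, 4, 5, 10, 6, 7)
compiled = torch.compile(shift_gather)
# out = compiled(x)  # first call triggers compilation, so warm up before profiling
```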
```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name                                                     Self CPU %   Self CPU     CPU total %  CPU total    CPU time avg Self CUDA    Self CUDA %  CUDA total   CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void at::native::index_elementwise_kernel<128, 4, at...  0.00%        0.000us      0.00%        0.000us      0.000us      40.000us     78.43%       40.000us     4.000us       10
void at::native::elementwise_kernel<128, 2, at::nati...  0.00%        0.000us      0.00%        0.000us      0.000us      11.000us     21.57%       11.000us     1.100us       10
cudaMemcpyAsync                                          31.91%       90.000us     31.91%       90.000us     4.500us      0.000us      0.00%        0.000us      0.000us       20
cudaStreamSynchronize                                    28.01%       79.000us     28.01%       79.000us     3.950us      0.000us      0.00%        0.000us      0.000us       20
Memcpy HtoD (Pageable -> Device)                         0.00%        0.000us      0.00%        0.000us      0.000us      0.000us      0.00%        0.000us      0.000us       20
cudaLaunchKernel                                         38.30%       108.000us    38.30%       108.000us    5.400us      0.000us      0.00%        0.000us      0.000us       20
cudaDeviceSynchronize                                    1.77%        5.000us      1.77%        5.000us      5.000us      0.000us      0.00%        0.000us      0.000us       1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 282.000us
Self CUDA time total: 51.000us
```
```
------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name                      Self CPU %   Self CPU     CPU total %  CPU total    CPU time avg Self CUDA    Self CUDA %  CUDA total   CUDA time avg # of Calls
------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
triton__0d1d2d3d4de       0.00%        0.000us      0.00%        0.000us      0.000us      1.000us      100.00%      1.000us      1.000us       1
cuLaunchKernel            99.63%       2.147ms      99.63%       2.147ms      2.147ms      0.000us      0.00%        0.000us      0.000us       1
cudaDeviceSynchronize     0.37%        8.000us      0.37%        8.000us      8.000us      0.000us      0.00%        0.000us      0.000us       1
------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.155ms
Self CUDA time total: 1.000us
```
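For completeness, the loop can also be removed entirely with one advanced-indexing call. This is a sketch of an alternative not shown in the answer above; it builds all shifted column indices at once, using the fact that `torch.roll(torch.arange(c), -i)[:s1]` is just `(i + arange(s1)) % c`:

```python
import torch

a, b, s1, c, t, d = 2, 4, 5, 10, 6, 7
x = torch.rand(a, b, s1, c, t, d)

# rows[j] pairs with cols[j, i] = (i + j) % c, the same index the rolled
# arange produced inside the loop.
rows = torch.arange(s1).unsqueeze(1)              # (s1, 1)
cols = (torch.arange(c).unsqueeze(0) + rows) % c  # (s1, c)

# The two index tensors broadcast to (s1, c) and sit on adjacent dims 2 and 3,
# so the result keeps the full (a, b, s1, c, t, d) shape in one gather.
x_new = x[:, :, rows, cols]
```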