Torch.roll different shifts according to another dimension index

Given a torch.Tensor w of shape (b, c, d), I want to find an efficient implementation of the following:

w_rolled = torch.zeros_like(w)
for j in range(b):
    w_rolled[j] = torch.roll(w[j], j, dims=-2)

Is there a way to use built-in PyTorch functions to avoid the ugly for loop and make this roll operation more efficient?

You could fuse the loop with torch.compile:

import torch
from torch.profiler import profile, ProfilerActivity

def fun(w_rolled, w):
    # roll batch element j by j positions along dim -2
    for j in range(b):
        w_rolled[j] = torch.roll(w[j], j, dims=-2)


b, c, d = 64, 64, 64
device = "cuda"

w = torch.randn(b, c, d, device=device)
w_rolled = torch.zeros_like(w)

# warmup
for _ in range(10):
    fun(w_rolled, w)

with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    fun(w_rolled, w)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# void at::native::roll_cuda_kernel<float>(float const...         0.00%       0.000us         0.00%       0.000us       0.000us      66.000us        50.77%      66.000us       1.031us            64  
#                          Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      64.000us        49.23%      64.000us       1.000us            64  
#                                        cudaLaunchKernel        43.66%     210.000us        43.66%     210.000us       3.281us       0.000us         0.00%       0.000us       0.000us            64  
#                                         cudaMemcpyAsync        55.51%     267.000us        55.51%     267.000us       4.172us       0.000us         0.00%       0.000us       0.000us            64  
#                                   cudaDeviceSynchronize         0.83%       4.000us         0.83%       4.000us       4.000us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 481.000us
# Self CUDA time total: 130.000us

fun_compiled = torch.compile(fun)

for _ in range(10):
    fun_compiled(w_rolled, w)
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    fun_compiled(w_rolled, w)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# -------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                      Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#         triton__0d1d2d3de         0.00%       0.000us         0.00%       0.000us       0.000us      29.000us       100.00%      29.000us      29.000us             1  
#            cuLaunchKernel        57.14%      12.000us        57.14%      12.000us      12.000us       0.000us         0.00%       0.000us       0.000us             1  
#     cudaDeviceSynchronize        42.86%       9.000us        42.86%       9.000us       9.000us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 21.000us
# Self CUDA time total: 29.000us
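
If you want to avoid the Python loop entirely, the same per-batch roll can also be expressed with a single torch.gather by building the source row indices explicitly. A rough sketch (roll_per_batch is just a name I made up, not a built-in):

import torch

def roll_per_batch(w):
    # w: (b, c, d); batch element j is rolled by j positions along dim -2
    b, c, d = w.shape
    shifts = torch.arange(b, device=w.device)             # shift j for batch element j
    rows = torch.arange(c, device=w.device)               # output row indices
    src = (rows.unsqueeze(0) - shifts.unsqueeze(1)) % c   # (b, c): source row for each output row
    index = src.unsqueeze(-1).expand(b, c, d)             # broadcast the index over the last dim
    return torch.gather(w, 1, index)                      # out[j, i, k] = w[j, (i - j) % c, k]

# sanity check against the loop version (standalone example)
w_test = torch.randn(8, 5, 3)
ref = torch.stack([torch.roll(w_test[j], j, dims=-2) for j in range(8)])
assert torch.equal(roll_per_batch(w_test), ref)

This launches one gather kernel instead of a roll kernel plus a device-to-device copy per batch element; whether it actually beats the compiled loop is worth profiling on your shapes.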

Thanks! A follow-up question: what if w is a complex tensor? Does torch.compile work for it? I get the following warning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
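
Would viewing the tensor as real work around that, so that Inductor only sees real ops? A rough sketch of what I have in mind for a single roll (assuming the shift is along a dim other than the trailing real/imag component dim):

import torch

def roll_complex(w, shift):
    # expose the complex tensor as a real one with a trailing size-2 dim
    w_real = torch.view_as_real(w)               # (b, c, d, 2) for complex (b, c, d)
    # the original dim -2 becomes dim -3 in the real view
    rolled = torch.roll(w_real, shift, dims=-3)
    return torch.view_as_complex(rolled.contiguous())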