Given a torch.Tensor w of shape (b, c, d), I want to find an efficient implementation of the following:
w_rolled = torch.zeros_like(w)
for j in range(b):
    w_rolled[j] = torch.roll(w[j], j, dims=-2)
Is there a way to use built-in PyTorch functions to avoid the ugly for loop and make this roll operation more efficient?
You could fuse the loop with torch.compile:
import torch
from torch.profiler import profile, record_function, ProfilerActivity
def fun(w_rolled, w):
    for j in range(b):
        w_rolled[j] = torch.roll(w[j], j, dims=-2)
b, c, d = 64, 64, 64
device = "cuda"
w = torch.randn(b, c, d, device=device)
w_rolled = torch.zeros_like(w)
# warmup
for _ in range(10):
    fun(w_rolled, w)
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    fun(w_rolled, w)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# void at::native::roll_cuda_kernel<float>(float const... 0.00% 0.000us 0.00% 0.000us 0.000us 66.000us 50.77% 66.000us 1.031us 64
# Memcpy DtoD (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 64.000us 49.23% 64.000us 1.000us 64
# cudaLaunchKernel 43.66% 210.000us 43.66% 210.000us 3.281us 0.000us 0.00% 0.000us 0.000us 64
# cudaMemcpyAsync 55.51% 267.000us 55.51% 267.000us 4.172us 0.000us 0.00% 0.000us 0.000us 64
# cudaDeviceSynchronize 0.83% 4.000us 0.83% 4.000us 4.000us 0.000us 0.00% 0.000us 0.000us 1
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Self CPU time total: 481.000us
# Self CUDA time total: 130.000us
# compile the loop so Inductor can fuse the per-batch rolls into a single kernel
fun_compiled = torch.compile(fun)
# warmup (also triggers compilation)
for _ in range(10):
    fun_compiled(w_rolled, w)
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    fun_compiled(w_rolled, w)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
# ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# triton__0d1d2d3de 0.00% 0.000us 0.00% 0.000us 0.000us 29.000us 100.00% 29.000us 29.000us 1
# cuLaunchKernel 57.14% 12.000us 57.14% 12.000us 12.000us 0.000us 0.00% 0.000us 0.000us 1
# cudaDeviceSynchronize 42.86% 9.000us 42.86% 9.000us 9.000us 0.000us 0.00% 0.000us 0.000us 1
# ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Self CPU time total: 21.000us
# Self CUDA time total: 29.000us
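If you also want a loop-free formulation built only from existing ops (what the original question asked about), here is one possible sketch (not benchmarked here; roll_per_batch is just an illustrative name). It precomputes the rolled row indices and uses torch.gather:

import torch

def roll_per_batch(w: torch.Tensor) -> torch.Tensor:
    # Reproduces w_rolled[j] = torch.roll(w[j], j, dims=-2) for all j at once.
    # torch.roll(x, s) along a size-n dim gives out[i] = x[(i - s) % n],
    # so gather each output row from its shifted source row along dim 1.
    b, c, d = w.shape
    shifts = torch.arange(b, device=w.device).unsqueeze(1)   # (b, 1), shift = j
    rows = torch.arange(c, device=w.device).unsqueeze(0)     # (1, c)
    idx = (rows - shifts) % c                                 # (b, c) source rows
    return w.gather(1, idx.unsqueeze(-1).expand(b, c, d))

Whether this beats the compiled loop depends on the shapes (it materializes a b-by-c index tensor), so it is worth profiling both variants the same way as above.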
Thanks! A follow-up question: what if w is a complex tensor? Does torch.compile work for it? I get the following warning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
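In case it helps while Inductor lacks complex codegen: one workaround sometimes used (a sketch only, under the assumption that working on a real view of the data is acceptable for this op; fun_real is an illustrative name) is to hand torch.compile real views of the complex tensors via torch.view_as_real, so the compiled code only ever sees real dtypes:

import torch

def fun_real(w_rolled_r, w_r):
    # w_r is the real view of complex w, shape (b, c, d, 2); roll along the
    # c dimension (dim 0 of w_r[j]) and leave the trailing real/imag axis alone.
    for j in range(w_r.shape[0]):
        w_rolled_r[j] = torch.roll(w_r[j], j, dims=0)

fun_real_compiled = torch.compile(fun_real)

w = torch.randn(64, 64, 64, dtype=torch.complex64, device="cuda")
w_rolled = torch.zeros_like(w)
# view_as_real returns views that share storage with the complex tensors,
# so writing through them fills w_rolled in place.
fun_real_compiled(torch.view_as_real(w_rolled), torch.view_as_real(w))

Since a roll only permutes rows, it treats the real and imaginary parts identically; still, verify the compiled result against the eager loop for your actual op.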