The most commonly used Rotary Positional Embedding (RoPE) implementation looks like this:
def rotate_half(t: torch.Tensor) -> torch.Tensor:
    t_1, t_2 = torch.chunk(t, 2, dim=-1)
    return torch.cat((-t_2, t_1), dim=-1)

def apply_rotary_pos_emb_bshd(t: torch.Tensor, freqs: torch.Tensor):
    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty, so the rotary pos embedding is applied to all of t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    # the first term is the cosine component; the second term is the sine
    # component, whose signs are flipped by rotate_half
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)
    t = (t * cos_) + (rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)
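For context, the snippet above assumes freqs is already built elsewhere. A minimal sketch of the standard construction (the function name build_rope_freqs and the base=10000.0 default are my own choices; the duplicated layout matches the first-half/second-half split that rotate_half uses):

import torch

def build_rope_freqs(max_seq_len: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    # standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, inv_freq)   # (max_seq_len, dim // 2)
    # duplicate the angles so the layout matches rotate_half's chunk(..., 2, dim=-1) split
    return torch.cat((angles, angles), dim=-1)  # (max_seq_len, dim)

(In the test script below, freqs is just random noise, which is fine for profiling purposes.)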
When I profile apply_rotary_pos_emb_bshd with torch.profiler, the exported trace surprised me: a single large ToCopyBackward op takes up most of the profiled duration. Here is the test code:
#!/usr/bin/env python
# encoding: utf-8
import os
import time
from functools import partial

import torch
from torch.profiler import (
    profile,
    ProfilerActivity,
    schedule,
)

def trace_handler(profiler, file_path, op_name):
    file_path = os.path.join(file_path, f"profiling-{op_name}.trace.json")
    profiler.export_chrome_trace(file_path)

def get_profiler(file_path, op_name):
    warmup = 5
    profile_schedule = schedule(wait=2, warmup=warmup, active=1)
    profiler = profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=profile_schedule,
        record_shapes=True,
        on_trace_ready=partial(trace_handler, file_path=file_path, op_name=op_name),
        with_flops=True,
        profile_memory=True,
        with_modules=True,
    )
    return profiler

def rotate_half(t: torch.Tensor) -> torch.Tensor:
    t_1, t_2 = torch.chunk(t, 2, dim=-1)
    return torch.cat((-t_2, t_1), dim=-1)

def apply_rotary_pos_emb_bshd(t: torch.Tensor, freqs: torch.Tensor):
    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty, so the rotary pos embedding is applied to all of t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    # the first term is the cosine component; the second term is the sine
    # component, whose signs are flipped by rotate_half
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)
    t = (t * cos_) + (rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)

def test_ops(op_func, op_name, in_params: dict):
    # warm-up pass
    out = op_func(**in_params)
    loss = out.sum()
    loss.backward()
    profiler = get_profiler("/workspace", op_name)
    test_iters = 10
    torch.cuda.synchronize()
    start = time.time()
    with profiler as prof:
        for _ in range(test_iters):
            out = op_func(**in_params)
            loss = out.sum()
            loss.backward()
            prof.step()
    torch.cuda.synchronize()
    using_time = time.time() - start
    print(f'{op_name} \t cost: {using_time}')

def test_rope():
    max_seq_len = 4096
    batch_size = 10
    head_num = 32
    dim = 128 * 32
    dim = dim // head_num
    input_shape = (max_seq_len, batch_size, head_num, dim)
    input_ts = torch.randn(input_shape, dtype=torch.float32, requires_grad=True)
    freqs_cis = torch.randn(max_seq_len, dim)
    freqs_cis_4d = freqs_cis.reshape(max_seq_len, 1, 1, dim)
    input_data_out_F = {
        "t": input_ts.cuda(),
        "freqs": freqs_cis_4d.cuda(),
    }
    test_ops(op_func=apply_rotary_pos_emb_bshd,
             op_name="rope",
             in_params=input_data_out_F,
             )

if __name__ == '__main__':
    test_rope()
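For anyone reproducing this without opening the Chrome trace: the same hot ops can be printed in-process with the standard key_averages() API (the sort key and row limit below are my own choices), e.g. by appending this after the with block in test_ops:

    # print the ops with the largest self CUDA time after profiling
    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))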
I run it with CUDA_LAUNCH_BLOCKING=1 python test.py; the CUDA version is 12.2.
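One detail that may or may not be relevant (my own guess, not something I have confirmed from the trace): input_ts is created on the CPU with requires_grad=True and only moved to the GPU via .cuda() when building the input dict, so the host-to-device copy is itself recorded by autograd, and its backward copies the full gradient back to the CPU on every iteration. A minimal check that such a copy shows up as a ToCopyBackward0 node:

import torch

x = torch.randn(4, requires_grad=True)  # CPU leaf tensor
y = x.cuda()                            # differentiable host-to-device copy
print(y.grad_fn)                        # <ToCopyBackward0 object at 0x...>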
Does anyone have a clue why this happens? I tested Conv2d and other modules and functional ops, and none of them produces output like this.