Composite RoPE backward gives a large ToCopyBackward0 in profiling trace

The most commonly used Rotary Positional Embedding (RoPE) implementation looks like this:

def rotate_half(t: torch.Tensor) -> torch.Tensor:
    t_1, t_2 = torch.chunk(t, 2, dim=-1)
    return torch.cat((-t_2, t_1), dim=-1)

def apply_rotary_pos_emb_bshd(t: torch.Tensor, freqs: torch.Tensor):
    rot_dim = freqs.shape[-1]

    # ideally t_pass is empty, so the rotary pos embedding is applied to the whole tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is the cosine component
    # second part is the sine component; its signs are flipped via rotate_half
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    t = (t * cos_) + (rotate_half(t) * sin_)

    return torch.cat((t, t_pass), dim=-1)
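
For reference, it is typically driven like this (the frequency construction here is only an illustration; the test below just uses random freqs):

import torch

# Illustrative only: build broadcastable rotary frequencies for a
# [seq, batch, head, dim] tensor and apply the embedding.
seq_len, batch, heads, head_dim = 16, 2, 4, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)        # [seq, head_dim // 2]
freqs = torch.cat((freqs, freqs), dim=-1).reshape(seq_len, 1, 1, head_dim)

t = torch.randn(seq_len, batch, heads, head_dim)
out = apply_rotary_pos_emb_bshd(t, freqs)                           # same shape as t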

When I test this composite function with torch.profiler, the output JSON surprised me quite a bit:

As you can see, there is a large ToCopyBackward0 that takes up most of the profiling duration. Here is the test code:

#!/usr/bin/env python
# encoding: utf-8
import torch
import time
import os
from functools import partial
from torch.profiler import (
    profile,
    ProfilerActivity,
    schedule,
)


def trace_handler(profiler, file_path, op_name):
    file_path = os.path.join(file_path, f"profiling-{op_name}.trace.json")
    profiler.export_chrome_trace(file_path)


def get_profiler(file_path, op_name):
    warmup = 5
    profile_schedule = schedule(wait=2, warmup=warmup, active=1)
    profiler = profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=profile_schedule,
        record_shapes=True,
        on_trace_ready=partial(trace_handler, file_path=file_path, op_name=op_name),
        with_flops=True,
        profile_memory=True,
        with_modules=True,
    )
    return profiler

def rotate_half(t: torch.Tensor) -> torch.Tensor:
    t_1, t_2 = torch.chunk(t, 2, dim=-1)
    return torch.cat((-t_2, t_1), dim=-1)


def apply_rotary_pos_emb_bshd(t: torch.Tensor, freqs: torch.Tensor):
    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty, so the rotary pos embedding is applied to the whole tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is the cosine component
    # second part is the sine component; its signs are flipped via rotate_half
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    t = (t * cos_) + (rotate_half(t) * sin_)

    return torch.cat((t, t_pass), dim=-1)

def test_ops(op_func, op_name, in_params: dict):
    # for warm up
    out = op_func(**in_params)
    loss = out.sum()
    loss.backward()

    profiler = get_profiler("/workspace", op_name)
    test_iters = 10
    torch.cuda.synchronize()
    start = time.time()
    with profiler as prof:
        for _ in range(test_iters):
            out = op_func(**in_params)
            loss = out.sum()
            loss.backward()
            prof.step()

    torch.cuda.synchronize()
    using_time = time.time() - start
    print(f'{op_name} \t cost: {using_time:.4f} s')


def test_rope():
    max_seq_len = 4096
    batch_size = 10
    head_num = 32
    dim = 128 * 32
    dim = dim // head_num
    input_shape = (max_seq_len, batch_size, head_num, dim)

    input_ts = torch.randn(input_shape, dtype=torch.float32, requires_grad=True)

    freqs_cis = torch.randn(max_seq_len, dim)
    freqs_cis_4d = freqs_cis.reshape(max_seq_len, 1, 1, dim)

    input_data_out_F = {
        "t": input_ts.cuda(),
        "freqs": freqs_cis_4d.cuda()
    }
    test_ops(op_func=apply_rotary_pos_emb_bshd,
        op_name="rope",
        in_params=input_data_out_F,
    )


if __name__ == '__main__':
    test_rope()

I run it with CUDA_LAUNCH_BLOCKING=1 python test.py; the CUDA version is 12.2.

Does anyone have a clue why this happens? I tested Conv2d and other modules and functions, and none of them produces output like this.

The ToCopyBackward0 itself is not taking all the time. It is triggering a device-to-host synchronization and then waiting for all of the queued-up operations on the device to finish. The copies come from these lines:

    input_data_out_F = {
        "t": input_ts.cuda(),
        "freqs": freqs_cis_4d.cuda()
    }
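
A minimal sketch of where that node comes from (shapes here are arbitrary): moving a CPU leaf tensor that requires grad onto the GPU records a copy node in the autograd graph, and its backward has to copy the gradient back to the host:

import torch

x_cpu = torch.randn(4, 4, requires_grad=True)  # leaf tensor lives on the CPU
x_gpu = x_cpu.cuda()                           # non-leaf copy on the GPU
print(x_gpu.grad_fn)                           # a ToCopyBackward0 node, as seen in the trace
x_gpu.sum().backward()                         # grad is copied back to the CPU leaf,
                                               # which forces a device-to-host sync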

If you want to avoid this synchronization, you can initialize your input_ts on CUDA to begin with, and then you should not need these device transfers at all.
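
A sketch of that change, using the shapes and names from the script above (untested, just to illustrate the suggestion):

# Allocate the leaf tensors directly on the GPU, so autograd never records a
# ToCopyBackward0 node and backward never has to copy gradients back to the host.
input_ts = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
freqs_cis = torch.randn(max_seq_len, dim, device="cuda")
freqs_cis_4d = freqs_cis.reshape(max_seq_len, 1, 1, dim)

input_data_out_F = {
    "t": input_ts,        # already on the device, no .cuda() call needed
    "freqs": freqs_cis_4d,
}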

Thanks! That solves the problem.