CUDA device-side assert triggered by trying to move a tensor to CPU

I’m running into some trouble running modeling_llama.py from transformers (transformers/src/transformers/models/llama/modeling_llama.py at main · huggingface/transformers · GitHub).
These are the functions I’m using now:

def rotate_half(x):
    """Split the last dimension of ``x`` into its two halves.

    Debugging variant of the usual "rotate half" helper: instead of
    returning ``torch.cat((-x2, x1), dim=-1)`` it hands back the two halves
    so the caller can perform the concatenation as a separate step.

    Args:
        x: Object exposing ``.shape`` and supporting ``x[..., a:b]`` slicing;
            ``.device`` is additionally read when ``DEBUG`` is truthy.

    Returns:
        Tuple ``(x1, x2)`` where ``x1 = x[..., :half]`` and
        ``x2 = x[..., half:]`` with ``half = x.shape[-1] // 2``.
    """
    if DEBUG:
        print("rotate_half was just called...")
        print(f"it received x of shape {x.shape} and placed on {x.device}")
        input("go on?")

    if DEBUG:
        print("calculating half a shape element...")
    # Must be computed unconditionally: the slices below need it even when
    # DEBUG is off (it used to live inside the DEBUG guard -> NameError).
    half_a_shape = x.shape[-1] // 2
    if DEBUG:
        input("go on?")

    if DEBUG:
        print("forming x1...")
        input("go on?")
    x1 = x[..., :half_a_shape]

    if DEBUG:
        print("forming x2...")
        input("go on?")
    x2 = x[..., half_a_shape:]

    if DEBUG:
        print(f"we now have x1 of shape {x1.shape} and x2 of shape {x2.shape}")
        print("concating -x2 and x1 along dim -1...")
        input("go on?")

    # The non-debug form of this helper would finish with:
    #     rotated_x = torch.cat((-x2, x1), dim=-1)
    #     return rotated_x
    # Here the halves are returned so the concat can be done separately.
    return x1, x2

def magic_concat(x1, x2):
    """Concatenate ``(-x2, x1)`` along the last dimension, doing the work on CPU.

    Debugging workaround: both inputs are copied to the CPU, concatenated
    there, and the result is moved back to the device ``x1`` started on.

    Args:
        x1: First half tensor; its device decides where the result ends up.
        x2: Second half tensor; it is negated and placed first in the output.

    Returns:
        ``torch.cat((-x2, x1), dim=-1)`` placed on the original device of ``x1``.
    """
    if DEBUG:
        print("let's try some on-CPU magic...")

    temp_device = x1.device
    our_cpu = torch.device("cpu")
    # NOTE(review): a device-to-host copy has to wait for pending GPU work,
    # so an asynchronous CUDA error raised by an earlier kernel can surface
    # on these lines even though the copy itself is not at fault.
    x1 = x1.to(our_cpu)
    x2 = x2.to(our_cpu)

    if DEBUG:
        print("forming the magic list")
    magic_list = (-x2, x1)

    if DEBUG:
        print("our dim is -1")
    dim = -1

    if DEBUG:
        print("and now actual concating...")
    thingy = torch.cat(magic_list, dim=dim)

    if DEBUG:
        print("...and now back to the original device...")
    thingy = thingy.to(temp_device)

    return thingy

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to the query and key tensors.

    ``cos`` and ``sin`` are gathered along dim 2 at the positions given by
    ``position_ids``; each of ``q`` and ``k`` is then combined as
    ``t * cos + rotated(t) * sin``, where ``rotated(t)`` is
    ``cat((-t2, t1), dim=-1)`` built via ``rotate_half`` + ``magic_concat``.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; indexed as a 4-D tensor (``shape[1]`` and
            ``shape[3]`` are read, dim 2 is gathered over).
        sin: Sine table, used exactly like ``cos``.
        position_ids: Integer index tensor, indexed as ``[:, None, :, None]``
            (i.e. treated as ``[bs, seq_len]``).

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    if DEBUG:
        print(f"apply_rotary_pos_emb has received q {q}, \nk {k}, \ncos {cos}, \nsin {sin} and \nposition_ids {position_ids} ")

    if DEBUG:
        print("calculating gather indices...")
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])

    if DEBUG:
        print("calculating cos and sin...")
    # NOTE(review): torch.gather needs every index to be a valid position
    # along dim 2 of cos/sin. An out-of-range position_id would presumably
    # show up on CUDA as a device-side assert reported later - verify the
    # range of position_ids against cos.shape[2].
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)

    if DEBUG:
        print("calculating q_embed...")
        print(f"q is on device {q.device}, \ncos is on {cos.device}, sin is on {sin.device}")

    # The computations below run unconditionally; only the prints and the
    # interactive pauses are debug-only (previously the whole computation
    # sat inside the DEBUG guard, so q_embed failed with DEBUG off).
    if DEBUG:
        print("let's take this one step at a time...")
        print("rotating half of q...")
    x1, x2 = rotate_half(q)
    thingy = magic_concat(x1, x2)

    if DEBUG:
        input("go on?")
        print("calculating the first part of q_embed...")
    first_part = q * cos

    if DEBUG:
        input("go on?")
        print("calculating the second part of q_embed...")
    second_part = thingy * sin

    if DEBUG:
        input("go on?")
        print ("adding the parts. shouldn't cause trouble...")
    q_embed = first_part + second_part

    if DEBUG:
        print("calculating k_embed...")
    # rotate_half returns the two halves as a tuple, so they have to be
    # concatenated before multiplying by sin (tuple * tensor is an error).
    k_embed = (k * cos) + (magic_concat(*rotate_half(k)) * sin)
    return q_embed, k_embed

and the error I’m getting is this:

File "/root/anaconda3/envs/new2_bud_llm_finetune/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 200, in magic_concat
    x1 = x1.to(our_cpu)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The reason for such a weird form of the functions is that their original form triggered another assert error: it turns out I can’t concatenate (-x2, x1), and I can’t even create the (-x2, x1) tuple for some reason.

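You are running into a valid CUDA assert and should rerun the code with CUDA_LAUNCH_BLOCKING=1, as suggested in the error message, to get a proper stack trace pointing to the failing operation. Because CUDA kernels run asynchronously, the line currently shown in the traceback (the copy to the CPU) is most likely only where the earlier failure was reported, not the operation that caused it.
Once the assert is triggered, the CUDA context is corrupted and all subsequent CUDA operations will fail.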