I’m running into some trouble running modeling_llama
from transformers
(transformers/modeling_llama.py at main · huggingface/transformers · GitHub).
These are the functions I’m using now:
import torch  # already imported at the top of modeling_llama.py

DEBUG = True  # module-level flag for the debug prints below

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    if DEBUG:
        print("rotate_half was just called...")
        # print(f"it received x {x}")
        print(f"it received x of shape {x.shape} and placed on {x.device}")
        input("go on?")
    if DEBUG:
        print("calculating half a shape element...")
    half_a_shape = x.shape[-1] // 2
    input("go on?")
    if DEBUG:
        print("forming x1...")
        input("go on?")
    x1 = x[..., :half_a_shape]
    if DEBUG:
        print("forming x2...")
        input("go on?")
    x2 = x[..., half_a_shape:]
    if DEBUG:
        print(f"we now have x1 of shape {x1.shape} and x2 of shape {x2.shape}")
        print("concating -x2 and x1 along dim -1...")
        input("go on?")
    # rotated_x = torch.cat((-x2, x1), dim=-1)
    if DEBUG:
        print("returning")
        input("ok?")
    # return rotated_x
    return x1, x2
def magic_concat(x1, x2):
    if DEBUG:
        print("let's try some on-CPU magic...")
    temp_device = x1.device
    our_cpu = torch.device("cpu")
    x1 = x1.to(our_cpu)
    x2 = x2.to(our_cpu)
    if DEBUG:
        print("forming the magic list")
    magic_list = (-x2, x1)
    if DEBUG:
        print("our dim is -1")
    dim = -1
    if DEBUG:
        print("and now actual concating...")
    thingy = torch.cat(magic_list, dim=dim)
    if DEBUG:
        print("...and now back to the original device...")
    thingy = thingy.to(temp_device)
    if DEBUG:
        print("returning...")
    return thingy
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    if DEBUG:
        print(f"apply_rotary_pos_emb has received q {q}, \nk {k}, \ncos {cos}, \nsin {sin} and \nposition_ids {position_ids}")
    if DEBUG:
        print("calculating gather indices...")
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    if DEBUG:
        print("calculating cos and sin...")
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    if DEBUG:
        print("calculating q_embed...")
        print(f"q is on device {q.device}, \ncos is on {cos.device}, sin is on {sin.device}")
    if DEBUG:
        print("let's take this one step at a time...")
        print("rotating half of q...")
    # thingy = rotate_half(q)
    x1, x2 = rotate_half(q)
    # thingy = torch.cat((-x2, x1), dim=-1)
    thingy = magic_concat(x1, x2)
    input("go on?")
    print("calculating the first part of q_embed...")
    first_part = q * cos
    input("go on?")
    print("calculating the second part of q_embed...")
    second_part = thingy * sin
    input("go on?")
    print("adding the parts. shouldn't cause trouble...")
    q_embed = first_part + second_part
    # q_embed = (q * cos) + (rotate_half(q) * sin)
    if DEBUG:
        print("calculating k_embed...")
    # rotate_half now returns a tuple, so k has to go through magic_concat as well
    k1, k2 = rotate_half(k)
    k_embed = (k * cos) + (magic_concat(k1, k2) * sin)
    return q_embed, k_embed
and the error I’m getting is this:
...
File "/root/anaconda3/envs/new2_bud_llm_finetune/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 200, in magic_concat
x1 = x1.to(our_cpu)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
The reason the functions have such a weird form is that their original form triggered another assert error: it turns out I can’t concat (-x2, x1), and I can’t even create the (-x2, x1) tuple for some reason.
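Following the hint in the traceback, my next step is to rerun with CUDA_LAUNCH_BLOCKING=1 so the stack trace points at the kernel that actually fails, and to bounds-check the index tensors before they reach this code. This is just a sketch: sanity_check_inputs is a helper I made up, and vocab_size / max_seq_len would come from the model config:

# Rerun synchronously first so the reported line is the real failure point:
#   CUDA_LAUNCH_BLOCKING=1 python <my_script.py>

def sanity_check_inputs(input_ids, position_ids, vocab_size, max_seq_len):
    # device-side asserts are often out-of-range indices,
    # so check both index tensors before the forward pass
    assert input_ids.min().item() >= 0 and input_ids.max().item() < vocab_size, \
        f"input_ids out of range for vocab_size={vocab_size}"
    assert position_ids.min().item() >= 0 and position_ids.max().item() < max_seq_len, \
        f"position_ids out of range for max_seq_len={max_seq_len}"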