Hello. I was conducting a few experiments to see if changing the order of matrix multiplications in Linear Attention would yield the same results.
import torch
"""input and weights"""
in_original = torch.randn(1, 197, 192).to(torch.device("cpu"))
q_linaer = torch.randn(192, 192).to(torch.device("cpu"))
k_linaer = torch.randn(192, 192).to(torch.device("cpu"))
v_linear = torch.randn(192, 192).to(torch.device("cpu"))
q = torch.matmul(in_original, q_linaer)
k = torch.matmul(in_original, k_linaer)
v = torch.matmul(in_original, v_linear)
"""split"""
B, C, L = in_original.shape
q = q.reshape(B, 3, C, -1)
k = k.reshape(B, 3, C, -1)
v = v.reshape(B, 3, C, -1)
"""in_qw_kv"""
kv = torch.einsum("b h j c, b h j d -> b h c d", k, v)
q_linaer = torch.reshape(q_linaer,(B, 3, L, -1))
qwkv = torch.einsum("b h i c, b h c d -> b h i d", q_linaer, kv)
re_out = torch.einsum("b c i, b h i d -> b h c d", in_original, qwkv)
re_out = torch.reshape(re_out, (B, C, L))
"""qk-v"""
qk = torch.einsum("b h i c, b h j c -> b h i j", q, k)
mhsa = torch.einsum("b h i j, b h j d -> b h i d", qk, v)
mhsa = torch.reshape(mhsa, (B,C,L))
original_output = mhsa
"""q-kv"""
kv = torch.einsum("b h j c, b h j d -> b h c d", k, v)
qkv = torch.einsum("b h i c, b h c d -> b h i d", q, kv)
qkv = torch.reshape(qkv, (B, C, L))
print(qkv-mhsa)
print(re_out-mhsa)
tensor([[[ 0.0469, -0.0625, 0.0352, ..., 0.0469, -0.0312, 0.1094],
[ 0.1406, 0.0391, 0.0781, ..., 0.0781, -0.0312, -0.0781],
[ 0.0156, 0.0000, -0.1250, ..., 0.0625, 0.0312, 0.0469],
...,
[-0.0586, 0.0312, 0.0625, ..., 0.0156, -0.1250, -0.0312],
[-0.1094, 0.1250, 0.0625, ..., 0.0625, -0.1172, 0.0625],
[ 0.0625, -0.0312, -0.0234, ..., -0.0469, 0.1875, -0.0703]]])
tensor([[[ 261114.0625, -860036.0000, -656980.0000, ..., 731625.5625,
279413.4688, -20652.5312],
[ 494674.3125, -74906.0938, -359076.0000, ..., 181925.0469,
415182.0625, 458257.1250],
[-115711.3594, 366081.9375, -10439.5625, ..., -73224.4688,
-25387.6250, -303460.0000],
...,
[ -5849.0781, 280915.9375, 473818.0938, ..., -29121.1484,
670491.7500, -327646.3438],
[-152076.0781, 1076525.0000, -985963.1250, ..., -353680.5000,
-321628.1250, 271734.7812],
[ -72602.5312, 372593.5000, -612792.2500, ..., -144943.8438,
-260349.6875, 465726.4062]]])
Process finished with exit code 0
Upon checking the above code, I found that qkv
and mhsa
have the same values, but re_out
shows different values from mhsa
.
What could be the issue?