Can't reproduce same output with einsum

Hello. I was running a few experiments to check whether changing the order of the matrix multiplications in Linear Attention yields the same result.
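For context, the reordering I am testing relies on associativity of matrix multiplication: (Q K^T) V and Q (K^T V) are equal in exact arithmetic and only differ by floating-point rounding. A quick standalone check (shapes here are just illustrative):

import torch

Q = torch.randn(197, 64)
K = torch.randn(197, 64)
V = torch.randn(197, 64)

left = (Q @ K.T) @ V   # standard order: attention scores first
right = Q @ (K.T @ V)  # linear-attention order: aggregate K and V first
print(torch.allclose(left, right, rtol=1e-4))  # True, up to float rounding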

import torch

"""input and weights"""
in_original = torch.randn(1, 197, 192).to(torch.device("cpu"))
q_linear = torch.randn(192, 192).to(torch.device("cpu"))
k_linear = torch.randn(192, 192).to(torch.device("cpu"))
v_linear = torch.randn(192, 192).to(torch.device("cpu"))


q = torch.matmul(in_original, q_linear)
k = torch.matmul(in_original, k_linear)
v = torch.matmul(in_original, v_linear)

"""split"""
B, C, L = in_original.shape  # B=1 batch, C=197 tokens, L=192 channels
q = q.reshape(B, 3, C, -1)   # intended: split into 3 heads of 64 channels
k = k.reshape(B, 3, C, -1)
v = v.reshape(B, 3, C, -1)

"""in_qw_kv"""
kv = torch.einsum("b h j c, b h j d -> b h c d", k, v)
q_linear = torch.reshape(q_linear, (B, 3, L, -1))
qwkv = torch.einsum("b h i c, b h c d -> b h i d", q_linear, kv)
re_out = torch.einsum("b c i, b h i d -> b h c d", in_original, qwkv)
re_out = torch.reshape(re_out, (B, C, L))

"""qk-v"""
qk = torch.einsum("b h i c, b h j c -> b h i j", q, k)
mhsa = torch.einsum("b h i j, b h j d -> b h i d", qk, v)
mhsa = torch.reshape(mhsa, (B, C, L))
original_output = mhsa


"""q-kv"""
kv = torch.einsum("b h j c, b h j d -> b h c d", k, v)
qkv = torch.einsum("b h i c, b h c d -> b h i d", q, kv)
qkv = torch.reshape(qkv, (B, C, L))

print(qkv-mhsa)
print(re_out-mhsa)


tensor([[[ 0.0469, -0.0625,  0.0352,  ...,  0.0469, -0.0312,  0.1094],
         [ 0.1406,  0.0391,  0.0781,  ...,  0.0781, -0.0312, -0.0781],
         [ 0.0156,  0.0000, -0.1250,  ...,  0.0625,  0.0312,  0.0469],
         ...,
         [-0.0586,  0.0312,  0.0625,  ...,  0.0156, -0.1250, -0.0312],
         [-0.1094,  0.1250,  0.0625,  ...,  0.0625, -0.1172,  0.0625],
         [ 0.0625, -0.0312, -0.0234,  ..., -0.0469,  0.1875, -0.0703]]])
tensor([[[ 261114.0625, -860036.0000, -656980.0000,  ...,  731625.5625,
           279413.4688,  -20652.5312],
         [ 494674.3125,  -74906.0938, -359076.0000,  ...,  181925.0469,
           415182.0625,  458257.1250],
         [-115711.3594,  366081.9375,  -10439.5625,  ...,  -73224.4688,
           -25387.6250, -303460.0000],
         ...,
         [  -5849.0781,  280915.9375,  473818.0938,  ...,  -29121.1484,
           670491.7500, -327646.3438],
         [-152076.0781, 1076525.0000, -985963.1250,  ..., -353680.5000,
          -321628.1250,  271734.7812],
         [ -72602.5312,  372593.5000, -612792.2500,  ..., -144943.8438,
          -260349.6875,  465726.4062]]])


Checking the output, qkv and mhsa have the same values (the first difference tensor is only floating-point rounding error), but re_out is completely different from mhsa.
What could be the issue?

torch.reshape was the problem: it only reinterprets the flat memory layout, so reshaping (B, N, 3*d) into (B, 3, N, d) ends up slicing along the token axis instead of pulling the head axis out of the channels. Changing it to einops.rearrange, which actually permutes the axes, solved it.
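For anyone who runs into the same thing, here is a minimal sketch of the corrected comparison between the "q-kv" and "in_qw_kv" orderings. It assumes the heads are split with the pattern "b n (h d) -> b h n d" and the Q weight matrix along its output channels with "e (h c) -> h e c"; the weight-split pattern in particular is one possible reconstruction, not necessarily the exact code used:

import torch
from einops import rearrange

B, N, E, H = 1, 197, 192, 3  # batch, tokens, embed dim, heads
in_original = torch.randn(B, N, E)
q_linear = torch.randn(E, E)
k_linear = torch.randn(E, E)
v_linear = torch.randn(E, E)

q = in_original @ q_linear
k = in_original @ k_linear
v = in_original @ v_linear

# rearrange moves the head axis out of the channel axis; a bare reshape to
# (B, H, N, d) would instead slice the token axis and mix tokens across heads
q = rearrange(q, "b n (h d) -> b h n d", h=H)
k = rearrange(k, "b n (h d) -> b h n d", h=H)
v = rearrange(v, "b n (h d) -> b h n d", h=H)

# q-kv ordering
kv = torch.einsum("b h j c, b h j d -> b h c d", k, v)
qkv = torch.einsum("b h i c, b h c d -> b h i d", q, kv)

# in-qw-kv ordering: split the weight along its output channels, matching
# how q itself is split per head (assumed pattern, see note above)
w = rearrange(q_linear, "e (h c) -> h e c", h=H)
wkv = torch.einsum("h e c, b h c d -> b h e d", w, kv)
re_out = torch.einsum("b i e, b h e d -> b h i d", in_original, wkv)

print(torch.allclose(qkv, re_out, rtol=1e-3))  # True, up to float rounding

The same pitfall applies when merging the heads back: use rearrange(out, "b h n d -> b n (h d)") rather than torch.reshape(out, (B, N, E)), since the head axis has to be permuted back next to the channels before flattening.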