Einsum results don't match when I merge two matmuls into one

Hi,

My code looks like this:

```python
import torch

inten = torch.randn(16, 16, 224, 224)
out13 = torch.randn(16, 8, 224, 224)
outid = torch.cat([inten, out13], dim=1)  # concatenate along channels: 16 + 8 = 24
wd = torch.randn(32, 16)
wc = torch.randn(32, 8)
wcat = torch.cat([wd, wc], dim=1)  # shape (32, 24), matching outid's channels
out_wd = torch.einsum('abcd,eb->aecd', inten, wd)
out_wc = torch.einsum('abcd,eb->aecd', out13, wc)
out_cat = torch.einsum('abcd,eb->aecd', outid, wcat)
print((out_wd + out_wc - out_cat).abs().sum())
```

I would like to compute an einsum for each of two pairs of tensors and sum the results. I expected to get the same result by concatenating the two pairs along the contracted dimension `b` and computing a single einsum, but the code above shows a large difference. How can I get around this?
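One hedged way to check whether the merged einsum is actually equivalent is sketched below. It assumes the mismatch is float32 round-off rather than a logic error. The single einsum accumulates the 24 products per output in a different order than the two separate einsums plus the final addition, and `.abs().sum()` then adds those tiny per-element discrepancies over 16 × 32 × 224 × 224 ≈ 25.7 million outputs. The sketch uses smaller, hypothetical shapes, compares element-wise with `torch.allclose` instead of summing absolute differences, and repeats the computation in float64. The tolerances are illustrative choices, not values from the post.

```python
import torch

torch.manual_seed(0)

# Hypothetical smaller shapes so the check runs quickly; the channel layout matches the post.
inten = torch.randn(2, 16, 8, 8)
out13 = torch.randn(2, 8, 8, 8)
wd = torch.randn(32, 16)
wc = torch.randn(32, 8)

# Concatenate both operands along the contracted dimension b (16 + 8 = 24 channels).
outid = torch.cat([inten, out13], dim=1)
wcat = torch.cat([wd, wc], dim=1)


def split_vs_merged(dtype):
    """Return (sum of two separate einsums, one merged einsum), both computed in `dtype`."""
    split = (torch.einsum('abcd,eb->aecd', inten.to(dtype), wd.to(dtype))
             + torch.einsum('abcd,eb->aecd', out13.to(dtype), wc.to(dtype)))
    merged = torch.einsum('abcd,eb->aecd', outid.to(dtype), wcat.to(dtype))
    return split, merged


for dtype in (torch.float32, torch.float64):
    split, merged = split_vs_merged(dtype)
    diff = (split - merged).abs()
    # The summed difference grows with the number of elements; the max per-element one does not.
    print(dtype, 'summed abs diff:', diff.sum().item(), 'max abs diff:', diff.max().item())
```

If the maximum per-element difference is tiny in float32 and shrinks by several orders of magnitude in float64, the concatenation itself is correct and the large summed value comes only from accumulating round-off over many elements.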