For example, a.shape = (1024, 768) and b.shape = (768, 8, 512).
How can I get a result with shape (1024, 8, 512)?
I have tried @, torch.bmm, torch.mm and torch.matmul, but none of them seems to work in this situation.
Thanks!
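Here is a small sketch of why those attempts fail, based on PyTorch's documented rules for these functions. `torch.mm` only accepts two 2-D matrices, and `torch.bmm` expects two 3-D batches of matrices, so neither accepts a 2-D `a` with a 3-D `b`. `@` / `torch.matmul` does accept the pair, but it treats a 3-D right operand as a batch of matrices. It sees `b` as 768 matrices of shape (8, 512), and multiplying `a` (1024 x 768) by an 8 x 512 matrix needs 768 == 8. That check fails. The `attempts` dictionary and `failures` list are just illustration.

```python
import torch

a = torch.rand(1024, 768)
b = torch.rand(768, 8, 512)

# Each attempted operation from the question, wrapped so we can record whether it raises
attempts = {
    "a @ b": lambda: a @ b,               # b is read as a batch of 768 (8 x 512) matrices
    "torch.mm": lambda: torch.mm(a, b),   # mm requires both inputs to be 2-D
    "torch.bmm": lambda: torch.bmm(a, b), # bmm requires both inputs to be 3-D
}

failures = []
for name, op in attempts.items():
    try:
        op()
    except (RuntimeError, ValueError) as err:
        failures.append(name)
        print(f"{name} failed: {err}")
```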
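You can use torch.einsum. It works like magic here: the equation 'ij,jkl->ikl' multiplies along the shared 768-sized index j and sums it out, leaving a result of shape (1024, 8, 512).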
```python
import torch

a = torch.rand((1024, 768))
b = torch.rand((768, 8, 512))

# c[i, k, l] = sum over j of a[i, j] * b[j, k, l]
c = torch.einsum('ij,jkl->ikl', a, b)
print(c.shape)  # torch.Size([1024, 8, 512])
```
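The same contraction can also be written without einsum. The following is a minimal sketch that assumes a reasonably recent PyTorch. `torch.tensordot(a, b, dims=1)` contracts the last dimension of `a` with the first dimension of `b`. Alternatively, you can flatten `b` to (768, 8*512) so that a plain matrix product applies, then reshape the result back. The tolerances in the comparison are an arbitrary choice to absorb float32 rounding differences.

```python
import torch

a = torch.rand(1024, 768)
b = torch.rand(768, 8, 512)

# Reference result from einsum: sum over the shared index j
c_einsum = torch.einsum('ij,jkl->ikl', a, b)

# tensordot with dims=1 contracts a's last dim with b's first dim
c_tensordot = torch.tensordot(a, b, dims=1)

# Flatten b's trailing dims -> (768, 4096), do an ordinary matmul -> (1024, 4096),
# then restore the trailing dims; row-major reshape keeps index k * 512 + l aligned
c_matmul = (a @ b.reshape(768, -1)).reshape(1024, 8, 512)

print(c_tensordot.shape, c_matmul.shape)
print(torch.allclose(c_einsum, c_tensordot, rtol=1e-4, atol=1e-4))
print(torch.allclose(c_einsum, c_matmul, rtol=1e-4, atol=1e-4))
```

All three versions compute the same contraction. einsum is usually the most readable, because the equation states directly which index is summed over.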