How do I perform matrix multiplication between a 2-dim tensor and a 3-dim tensor?

e.g.
a.shape = (1024, 768) and b.shape = (768, 8, 512)
How can I get a result with shape (1024, 8, 512)?
I have tried the @ operator, torch.bmm, torch.mm, and torch.matmul, but none of them seem to work in this situation.
Thanks!

You can use torch.einsum. It contracts the shared dimension (768) and keeps the remaining ones, which is exactly what you want here.

import torch

a = torch.rand((1024, 768))
b = torch.rand((768, 8, 512))

# Sum over the shared index j (size 768); keep i, k and l.
c = torch.einsum('ij,jkl->ikl', a, b)

print(c.shape)  # torch.Size([1024, 8, 512])
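
If you prefer not to use einsum, here are two other ways that should give the same result. This is only a sketch using the shapes from the question; the variable names c_einsum, c_tensordot and c_reshape are mine. As far as I can tell, the reason torch.matmul and @ fail is that they treat the 3-dim tensor as a batch of 768 matrices of shape (8, 512), so the inner dimensions (768 vs 8) no longer match.

```python
import torch

a = torch.rand(1024, 768)
b = torch.rand(768, 8, 512)

# Reference result from the einsum approach.
c_einsum = torch.einsum('ij,jkl->ikl', a, b)

# Option 1: tensordot with dims=1 contracts the last dim of a
# with the first dim of b.
c_tensordot = torch.tensordot(a, b, dims=1)

# Option 2: flatten the trailing dims of b into one, do an ordinary
# 2-dim matmul, then restore the (8, 512) dims.
c_reshape = (a @ b.reshape(768, -1)).reshape(1024, 8, 512)

print(c_tensordot.shape)
print(c_reshape.shape)

# Loose tolerances, since float32 summation order may differ between methods.
print(torch.allclose(c_einsum, c_tensordot, rtol=1e-4, atol=1e-4))
print(torch.allclose(c_einsum, c_reshape, rtol=1e-4, atol=1e-4))
```

Which one to use is mostly a matter of taste: einsum spells out the indices explicitly, while tensordot and the reshape version stay closer to a plain matrix multiply.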