I have two tensors with a.shape = (1024, 768) and b.shape = (768, 8, 512).
How can I get a result with shape (1024, 8, 512)?
I have tried @, torch.bmm, torch.mm and torch.matmul, but none of them seems to work in this situation.
You can use torch.einsum. It lets you label each axis and sum over the one the two tensors share (j, of size 768), so no reshaping is needed:
import torch

a = torch.rand((1024, 768))
b = torch.rand((768, 8, 512))
# Sum over the shared index j; keep i from a and k, l from b.
c = torch.einsum('ij, jkl -> ikl', [a, b])
print(c.shape)  # torch.Size([1024, 8, 512])
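For context on why the calls in the question fail: `@` and `torch.matmul` treat the 3-D `b` as a batch of 768 matrices of shape (8, 512), so they try to multiply (1024, 768) by (8, 512) and hit a size mismatch, while `torch.mm` accepts only 2-D tensors and `torch.bmm` only 3-D tensors on both sides. The sketch below (random data, not part of the original answer) checks that two other formulations give the same result as the einsum call: `torch.tensordot(a, b, dims=1)`, and flattening `b` so a plain 2-D matmul applies.

```python
import torch

a = torch.rand(1024, 768)
b = torch.rand(768, 8, 512)

# Reference result: sum over the shared index j (size 768).
c_einsum = torch.einsum('ij,jkl->ikl', a, b)

# Option 1: tensordot with dims=1 contracts the last axis of a
# with the first axis of b, which is the same contraction.
c_tensordot = torch.tensordot(a, b, dims=1)

# Option 2: flatten b's trailing axes so an ordinary 2-D matmul applies,
# (1024, 768) @ (768, 8 * 512) -> (1024, 4096), then split the last axis back.
c_matmul = (a @ b.reshape(b.shape[0], -1)).reshape(a.shape[0], *b.shape[1:])

print(c_tensordot.shape, c_matmul.shape)
# Loose tolerances because float32 sums of 768 terms can differ in rounding.
print(torch.allclose(c_einsum, c_tensordot, rtol=1e-4, atol=1e-4))
print(torch.allclose(c_einsum, c_matmul, rtol=1e-4, atol=1e-4))
```

The reshape variant uses only a standard 2-D matrix multiply; note that `reshape` may copy `b` if it is not contiguous.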