Vector times matrix simple question


I can’t seem to find a way to do the following operation with one simple matrix multiplication. Given

import torch
a = torch.randn((2, 768))
b = torch.randn((2, 4, 768))

I want to multiply the first row of a, a[0], with the matrix b[0], and likewise a[1] with b[1]. In other words, I want the matrix-vector products b[0] @ a[0] and b[1] @ a[1]. Each product has shape (4,), so the stacked result should be a (2, 4) tensor.
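To make the target concrete, here is a minimal sketch of the operation, first as an explicit loop and then as a single batched product. The variable names out_loop and out_bmm are only illustrative, and using torch.bmm is one possible way to write it, not necessarily the best one.

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 768)
b = torch.randn(2, 4, 768)

# Reference: one matrix-vector product per batch element.
# b[n] has shape (4, 768) and a[n] has shape (768,), so each product is (4,).
out_loop = torch.stack([b[n] @ a[n] for n in range(a.shape[0])])

# Same computation as one batched matrix multiplication:
# (2, 4, 768) x (2, 768, 1) -> (2, 4, 1), then drop the trailing dimension.
out_bmm = torch.bmm(b, a.unsqueeze(-1)).squeeze(-1)

print(out_loop.shape, out_bmm.shape)
```

Both tensors should have shape (2, 4) and agree up to floating-point error.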

I know how to get this result using einsum; however, I also need to do this operation with nn.Bilinear (aWb, with W being the weight matrix). How should I shape the nn.Bilinear weight matrix to get the result I want?
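For comparison, here is a sketch of the einsum version next to one possible nn.Bilinear setup. It assumes a single output feature, no bias, and a weight of shape (1, 768, 768) set to the identity, with a repeated along the second dimension so that both inputs share their leading dimensions. These choices are assumptions made for illustration, not the only way to arrange it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
a = torch.randn(2, 768)
b = torch.randn(2, 4, 768)

# einsum version: out[n, i] = sum_k a[n, k] * b[n, i, k]
expected = torch.einsum("nk,nik->ni", a, b)

# nn.Bilinear computes x1^T W x2 for each output feature;
# its weight has shape (out_features, in1_features, in2_features).
bilinear = nn.Bilinear(768, 768, 1, bias=False)
with torch.no_grad():
    # Assumption: W = identity, so x1^T W x2 reduces to a plain dot product.
    bilinear.weight.copy_(torch.eye(768).unsqueeze(0))

# Both inputs must share every dimension except the last,
# so repeat each row of a once per row of b[n]: (2, 768) -> (2, 4, 768).
a_rep = a.unsqueeze(1).expand_as(b).contiguous()

with torch.no_grad():
    out_bilinear = bilinear(a_rep, b).squeeze(-1)  # (2, 4, 1) -> (2, 4)

print(out_bilinear.shape)
```

With the identity weight this reproduces the einsum result up to floating-point error. Leaving W trainable instead would give out[n, i] = a[n]^T W b[n, i], which is the aWb form with one shared (768, 768) matrix.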

Thanks for your time.