I can’t seem to find a way to do the following operation with one simple matrix multiplication. Given
a = torch.randn((2, 768))
b = torch.randn((2, 4, 768))
I want to multiply the first row of a, a[0], with the matrix b[0]. Similarly, I would like to get the product of a[1] and b[1]. The total result should be a (2,4) tensor.
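To make the target operation concrete, here is a minimal sketch of what I mean, assuming each row a[i] of shape (768,) is multiplied with the matching matrix b[i] of shape (4, 768). The variable names `expected`, `result` and `result_bmm` are only for illustration.

```python
import torch

a = torch.randn(2, 768)
b = torch.randn(2, 4, 768)

# Row-by-row reference: b[i] (4, 768) times a[i] (768,) gives a (4,) vector.
expected = torch.stack([b[i] @ a[i] for i in range(a.shape[0])])

# The same contraction in a single call with einsum:
# n = batch index, k = rows of b[n], d = the shared 768-dim axis.
result = torch.einsum("nd,nkd->nk", a, b)

# Equivalent batched matrix multiplication: (2, 4, 768) @ (2, 768, 1) -> (2, 4, 1).
result_bmm = torch.bmm(b, a.unsqueeze(-1)).squeeze(-1)

print(result.shape)  # torch.Size([2, 4])
```

All three variants should agree up to floating-point error.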
I know how to get this result using einsum. However, I also need to do this operation with nn.Bilinear (aWb, with W being the weight matrix). How should I shape the nn.Bilinear weight matrix to get the result that I want?
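For reference, this is a rough sketch of the kind of setup I have in mind, not a confirmed solution. It assumes that nn.Bilinear(768, 768, 1) is the right configuration (weight of shape (1, 768, 768)) and that a is broadcast along the second axis so both inputs share the leading dimensions (2, 4). The identity-weight check is only a sanity test I added for illustration.

```python
import torch
import torch.nn as nn

a = torch.randn(2, 768)
b = torch.randn(2, 4, 768)

# Assumption: one output feature, so the weight W has shape (1, 768, 768).
bilinear = nn.Bilinear(768, 768, 1, bias=False)

# Repeat each a[i] for every row of b[i]: (2, 768) -> (2, 4, 768).
a_exp = a.unsqueeze(1).expand(-1, b.shape[1], -1).contiguous()

# nn.Bilinear acts on the last dimension: (2, 4, 768) x (2, 4, 768) -> (2, 4, 1).
out = bilinear(a_exp, b).squeeze(-1)
print(out.shape)  # torch.Size([2, 4])

# Sanity check: with W set to the identity, a^T W b reduces to the plain
# dot product, so the output should match the einsum result.
with torch.no_grad():
    bilinear.weight.copy_(torch.eye(768).unsqueeze(0))
    check = bilinear(a_exp, b).squeeze(-1)
reference = torch.einsum("nd,nkd->nk", a, b)
```

With a learned W the output keeps the (2, 4) shape, but it no longer equals the einsum result.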
Thanks for your time.