Weighted Batch Matrix Multiplication


v1 : BatchSize x MaxSentenceLength x EmbSize
v2 : BatchSize x EmbSize x MaxSentenceLength


v3 = torch.bmm(v1, v2) : BatchSize x MaxSentenceLength x MaxSentenceLength

As I understand it, torch.bmm computes a simple dot product of the two embeddings:

                               sum(emb1[i] * emb2[i],  i = 0..EmbSize-1)

But I'm interested in a version of torch.bmm based on a weighted dot product of the two embeddings:

                               sum(w[i] * emb1[i] * emb2[i],  i = 0..EmbSize-1)

where w[i] are learned weights.

Is there a way to do this?

Thanks in advance!

If I understand correctly, w is a learnable parameter of size EmbSize.
What about doing the following?

import torch
import torch.nn as nn

BatchSize = 10
MaxSentenceLength = 20
EmbSize = 30

v1 = torch.randn(BatchSize, MaxSentenceLength, EmbSize)
v2 = torch.randn(BatchSize, EmbSize, MaxSentenceLength)
w = nn.Parameter(torch.randn(1, 1, EmbSize))  # broadcasts over batch and sentence dims

# Scaling v1 by w before bmm weights the i-th term of every dot product by w[i]
out = torch.bmm(w * v1, v2)  # BatchSize x MaxSentenceLength x MaxSentenceLength

Since w enters through an ordinary elementwise multiplication, it receives gradients during backprop like any other parameter.
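To see why this works, here is a quick numerical check (a sketch with small sizes, and a plain tensor for w to keep it simple) comparing torch.bmm(w * v1, v2) against an explicit loop that computes sum(w[i] * emb1[i] * emb2[i]) entry by entry:

```python
import torch

BatchSize, MaxSentenceLength, EmbSize = 2, 3, 4

v1 = torch.randn(BatchSize, MaxSentenceLength, EmbSize)
v2 = torch.randn(BatchSize, EmbSize, MaxSentenceLength)
w = torch.randn(1, 1, EmbSize)  # plain tensor stand-in for the learned weights

# Proposed one-liner: scale v1 by w, then batch-multiply
out = torch.bmm(w * v1, v2)

# Explicit weighted dot product for every (batch, row, column) entry
expected = torch.zeros(BatchSize, MaxSentenceLength, MaxSentenceLength)
for b in range(BatchSize):
    for i in range(MaxSentenceLength):
        for j in range(MaxSentenceLength):
            expected[b, i, j] = sum(
                w[0, 0, k] * v1[b, i, k] * v2[b, k, j] for k in range(EmbSize)
            )

assert torch.allclose(out, expected, atol=1e-5)
```

Both give the same result because multiplying v1 by w simply pre-scales the k-th coordinate of every row before the ordinary dot product in bmm.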