v1 : BatchSize x MaxSentenceLength x EmbSize
v2 : BatchSize x EmbSize x MaxSentenceLength
v3 = torch.bmm(v1, v2) : BatchSize x MaxSentenceLength x MaxSentenceLength
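For reference, the shapes above can be checked with a quick sketch (the concrete sizes here are made-up examples):

```python
import torch

batch, max_len, emb = 2, 5, 8  # hypothetical BatchSize, MaxSentenceLength, EmbSize

v1 = torch.randn(batch, max_len, emb)  # BatchSize x MaxSentenceLength x EmbSize
v2 = torch.randn(batch, emb, max_len)  # BatchSize x EmbSize x MaxSentenceLength

# bmm multiplies the two trailing matrix dims per batch element
v3 = torch.bmm(v1, v2)
print(v3.shape)  # torch.Size([2, 5, 5])
```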
As I understand it, torch.bmm computes a plain dot product of the two embeddings:
sum(emb1[i] * emb2[i], i = 0..EmbSize-1)
But what I'm interested in is a version of torch.bmm based on a weighted dot product of the two embeddings:
sum(w[i] * emb1[i] * emb2[i], i = 0..EmbSize-1)
where w[i] are learned weights.
Is there a way to do this?
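One possible approach (a sketch, not necessarily the only way): since the weights apply per embedding dimension, scaling v1 along its last dimension by a learnable vector before calling torch.bmm gives exactly sum(w[i] * emb1[i] * emb2[i]). The module name `WeightedBMM` below is my own invention:

```python
import torch
import torch.nn as nn

class WeightedBMM(nn.Module):
    """Batch matrix product where each embedding dimension has a learned weight."""

    def __init__(self, emb_size):
        super().__init__()
        # w[i] are the learned weights; initialized to 1 so the module
        # starts out equivalent to a plain torch.bmm
        self.w = nn.Parameter(torch.ones(emb_size))

    def forward(self, v1, v2):
        # v1: BatchSize x MaxSentenceLength x EmbSize
        # v2: BatchSize x EmbSize x MaxSentenceLength
        # Broadcasting multiplies v1's last dim elementwise by w, so the
        # subsequent bmm computes sum(w[i] * emb1[i] * emb2[i]) per position pair.
        return torch.bmm(v1 * self.w, v2)

torch.manual_seed(0)
m = WeightedBMM(8)
v1 = torch.randn(2, 5, 8)
v2 = torch.randn(2, 8, 5)
out = m(v1, v2)  # BatchSize x MaxSentenceLength x MaxSentenceLength
```

Since `w` is an `nn.Parameter`, it will be picked up by the optimizer and trained along with the rest of the model.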
Thanks in advance!