Hello,

v1 : BatchSize x MaxSentenceLength x EmbSize

v2 : BatchSize x EmbSize x MaxSentenceLength

and

v3 = torch.bmm(v1, v2) : BatchSize x MaxSentenceLength x MaxSentenceLength

As I understand it, torch.bmm computes a simple dot product of the two embeddings:

```
sum(emb1[i] * emb2[i], i = 0..EmbSize-1)
```

But I'm interested in a version of torch.bmm based on a weighted dot product of the two embeddings:

```
sum(w[i] * emb1[i] * emb2[i], i = 0..EmbSize-1)
```

where w[i] are learned weights.
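In code, what I have in mind looks something like the sketch below (all names are illustrative). Since the weighted sum factors as sum((w[i] * emb1[i]) * emb2[i]), I believe scaling v1 by w before the bmm would compute it via broadcasting, but I'm not sure this is the intended/idiomatic way:

```python
import torch
import torch.nn as nn

# Illustrative sizes, not from any real model
BatchSize, MaxSentenceLength, EmbSize = 4, 7, 16

v1 = torch.randn(BatchSize, MaxSentenceLength, EmbSize)
v2 = torch.randn(BatchSize, EmbSize, MaxSentenceLength)

# Learned per-dimension weights; broadcasting over the batch and
# sentence dimensions applies w[i] to each embedding component.
w = nn.Parameter(torch.ones(EmbSize))

# sum(w[i] * emb1[i] * emb2[i]) == sum((w[i] * emb1[i]) * emb2[i]),
# so scaling v1 by w first gives the weighted dot product.
v3 = torch.bmm(v1 * w, v2)  # BatchSize x MaxSentenceLength x MaxSentenceLength
```

Gradients would then flow into w through the elementwise multiplication, so w should be trainable like any other parameter.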

Is there a way to do this?

Thanks in advance!