Hi,
a has shape [batch_size, seq_len, embedding]
b has shape [batch_size, 1, embedding]
How do I get the dot product of b's [1, embedding] vector with each row of a, i.e. [1, embedding], [2, embedding], …
[seq, embedding]?
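Try either of these:
import torch

batch_size, seq_len, embedding = 3, 4, 5
a = torch.randn([batch_size, seq_len, embedding])
b = torch.randn([batch_size, 1, embedding])
# Option 1: broadcast multiply (b's row hits every row of a), then sum over embedding
print((a * b).sum(dim=-1).shape)  # torch.Size([3, 4])
# Option 2: batched matmul against b transposed to [batch_size, embedding, 1]
print(torch.bmm(a, b.permute(0, 2, 1)).shape)  # torch.Size([3, 4, 1])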
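To check that both options really compute the per-row dot products, here is a small sketch (assuming only that `torch` is installed; the sizes are arbitrary). It compares the broadcast-and-sum version and the `bmm` version against an explicit loop over `torch.dot`, and also shows an `torch.einsum` spelling of the same contraction, which is an extra alternative not mentioned above.

```python
import torch

# Shapes from the question; the concrete sizes are arbitrary.
batch_size, seq_len, embedding = 3, 4, 5
a = torch.randn(batch_size, seq_len, embedding)
b = torch.randn(batch_size, 1, embedding)

# 1) Broadcasting: b's single row is multiplied with every row of a,
#    then summed over the embedding dim -> [batch_size, seq_len]
scores_mul = (a * b).sum(dim=-1)

# 2) Batched matmul: [B, S, E] @ [B, E, 1] -> [B, S, 1], then drop the last dim
scores_bmm = torch.bmm(a, b.transpose(1, 2)).squeeze(-1)

# 3) einsum: contract over the embedding index e for each batch b and row s
scores_ein = torch.einsum("bse,be->bs", a, b.squeeze(1))

# Reference: compute each dot product one at a time
scores_loop = torch.empty(batch_size, seq_len)
for i in range(batch_size):
    for j in range(seq_len):
        scores_loop[i, j] = torch.dot(a[i, j], b[i, 0])

print(scores_mul.shape)  # torch.Size([3, 4])
print(torch.allclose(scores_mul, scores_loop, atol=1e-5))
print(torch.allclose(scores_bmm, scores_loop, atol=1e-5))
print(torch.allclose(scores_ein, scores_loop, atol=1e-5))
```

Note that `(a * b)` on its own is only the element-wise product with shape [batch_size, seq_len, embedding]; the `.sum(dim=-1)` is what turns it into dot products. The `bmm` result keeps a trailing dimension of size 1, so `.squeeze(-1)` is needed if you want a plain [batch_size, seq_len] tensor.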