Try the following, which compares a broadcast element-wise product with a batched matrix product:
# Toy dimensions for the demonstration.
batch_size = 3
seq_len = 4
embedding = 5

# `a` has one embedding vector per sequence position; `b` has a single
# embedding vector per batch item (its middle axis has size 1).
a = torch.randn(batch_size, seq_len, embedding)
b = torch.randn(batch_size, 1, embedding)

# Element-wise product: the size-1 middle axis of `b` is broadcast across
# the seq_len axis of `a`, so the result keeps the shape of `a`.
print(torch.mul(a, b).shape)  # torch.Size([3, 4, 5])

# Batched matrix product: swapping the last two axes of `b` gives
# (batch_size, embedding, 1), so each (seq_len, embedding) matrix in `a`
# is multiplied by an (embedding, 1) column.
print(torch.bmm(a, b.transpose(1, 2)).shape)  # torch.Size([3, 4, 1])