Matrix-vector multiplication (4D batched data)

Hi guys,

I need to compute a matrix-vector multiplication on batched (4D) data,
e.g. a [99, 6, 5, 128] tensor multiplied by a [99, 6, 1, 128] tensor to get a [99, 6, 5] result.
I tried to do it with matmul (torch.matmul), but it did not work. The best I could do was to
loop over dimensions 0 and 1 and, for each pair of indices,
compute torch.mv on a [5, 128] matrix and a [128] vector.
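For reference, the loop approach described above probably looks something like the sketch below. This is a reconstruction, not the poster's actual code, and the function name batched_mv_loop is hypothetical. It assumes the second tensor stores each vector with a singleton row dimension, as in the shapes given.

```python
import torch

def batched_mv_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. A: [N0, N1, M, K], B: [N0, N1, 1, K] -> result: [N0, N1, M]
    n0, n1, m, _ = A.shape
    out = A.new_empty((n0, n1, m))
    for i in range(n0):
        for j in range(n1):
            # torch.mv expects a 2D matrix [M, K] and a 1D vector [K],
            # so drop B's singleton row dimension with index 0.
            out[i, j] = torch.mv(A[i, j], B[i, j, 0])
    return out

A = torch.randn(99, 6, 5, 128)
B = torch.randn(99, 6, 1, 128)
print(batched_mv_loop(A, B).shape)  # torch.Size([99, 6, 5])
```

This works, but it launches one small torch.mv call per (i, j) pair, which is slow for large batches.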

Could you please suggest how to do that without loops?
Thanks in advance.

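How about (A * B).sum(-1)? Broadcasting expands B's size-1 dimension to match A's 5 rows, and summing over the last dimension gives the dot product of each row with the vector.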
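A minimal sketch of the suggested one-liner, checked against two equivalent formulations. The torch.matmul and torch.einsum variants are additions for comparison, not part of the original reply. The matmul version also shows the likely reason a plain matmul call fails here: the vector has to be turned into a [128, 1] column before multiplying.

```python
import torch

torch.manual_seed(0)
A = torch.randn(99, 6, 5, 128)  # batch of 5x128 matrices
B = torch.randn(99, 6, 1, 128)  # batch of 128-dim vectors with a singleton row dim

# 1) Broadcasting: [99, 6, 5, 128] * [99, 6, 1, 128] -> [99, 6, 5, 128],
#    then sum over the shared last dim -> [99, 6, 5].
out_bcast = (A * B).sum(-1)

# 2) torch.matmul: transpose B to [99, 6, 128, 1] so each vector is a column,
#    giving [99, 6, 5, 1]; squeeze the trailing dim -> [99, 6, 5].
out_matmul = torch.matmul(A, B.transpose(-1, -2)).squeeze(-1)

# 3) torch.einsum: drop B's singleton dim and sum over the labelled k index.
out_einsum = torch.einsum("ijmk,ijk->ijm", A, B.squeeze(2))

print(out_bcast.shape)  # torch.Size([99, 6, 5])
# Results agree up to floating-point summation order.
print(torch.allclose(out_bcast, out_matmul, atol=1e-4))
print(torch.allclose(out_bcast, out_einsum, atol=1e-4))
```

The broadcasting version materializes the full [99, 6, 5, 128] elementwise product before summing, while the matmul and einsum versions let PyTorch dispatch to a batched matrix multiply; for tensors this size either choice is fine.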

Wow, so easy. Thanks!