Row-Wise Dot Product

Suppose I have two tensors:

a = torch.randn(10, 1000, 1, 4)
b = torch.randn(10, 1000, 6, 4)

Here the third dimension indexes the vectors, and the last dimension holds each vector's 4 components.

For every position in the first two dimensions, I want the dot product of each of the 6 vectors in b with the single vector in a.

To illustrate, this is what I mean:

dots = torch.empty(10, 1000, 6, 1)
for i in range(10):
    for j in range(1000):
        for k in range(6):
            # loop variables renamed so they no longer shadow the tensor b
            dots[i, j, k] = torch.dot(b[i, j, k], a[i, j, 0])

How would I achieve this using torch functions?

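A batched matrix multiply does this. Swap the last two dimensions of a so that it has shape (10, 1000, 4, 1); torch.matmul then multiplies b's (6, 4) matrices by a's (4, 1) columns and returns shape (10, 1000, 6, 1). The loop in the code is only there to check the result: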

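import torch

a = torch.randn(10, 1000, 1, 4)
b = torch.randn(10, 1000, 6, 4)

# Reference result computed with explicit loops
dots = torch.empty(10, 1000, 6, 1)
for x in range(10):
    for y in range(1000):
        for z in range(6):
            dots[x, y, z] = torch.dot(b[x, y, z], a[x, y, 0])

# Batched matmul: (10, 1000, 6, 4) @ (10, 1000, 4, 1) -> (10, 1000, 6, 1)
ret = torch.matmul(b, a.permute(0, 1, 3, 2))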
print((ret - dots).abs().max())
> tensor(9.5367e-07)
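
The same result can probably also be obtained with broadcasting or einsum, which some people find easier to read. The following is a minimal sketch, not a benchmark; the names dots_sum and dots_einsum are made up for illustration, and the shapes are the ones from the question.

```python
import torch

torch.manual_seed(0)
a = torch.randn(10, 1000, 1, 4)
b = torch.randn(10, 1000, 6, 4)

# matmul version from the answer, with transpose instead of permute
ret = torch.matmul(b, a.transpose(-1, -2))          # shape (10, 1000, 6, 1)

# Broadcasting: a's size-1 dimension is expanded against the 6 vectors in b,
# then the elementwise products are summed over the component dimension.
dots_sum = (a * b).sum(dim=-1, keepdim=True)        # shape (10, 1000, 6, 1)

# einsum: contract the shared component index d, then restore the trailing 1.
dots_einsum = torch.einsum('bcvd,bcd->bcv', b, a.squeeze(2)).unsqueeze(-1)

print((dots_sum - ret).abs().max())
print((dots_einsum - ret).abs().max())
```

All three should agree up to float32 rounding error. The broadcasting version materializes the full (10, 1000, 6, 4) product before summing, so matmul or einsum may be preferable for large tensors.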