3D matrix of dot products

Hi! I have two tensors and need to compute their dot product along only one dimension. Both have the same shape (N, M, D), and I want to take the dot product over the last dimension D so that the result has shape (N, M, 1). How would I do that? Most solutions I've found seem to assume a specific number of dimensions.
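In index notation, the requested result is, for every pair of leading indices n and m (the trailing size-1 dimension just holds this scalar):

$$
z_{n,m} = \sum_{d=1}^{D} x_{n,m,d}\, y_{n,m,d}
$$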

I think you can use torch.einsum for this; see below for a small example. Doing a dot product over multiple dimensions felt a bit confusing at first, but it started to make sense to me after working through a few examples.

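```python
import torch

N, M, D = 2, 3, 4
x = torch.ones((N, M, D))
y = torch.ones((N, M, D))
# 'k' appears in both inputs but not in the output, so einsum sums over it
z = torch.einsum('ijk,ijk->ij', x, y).unsqueeze(2)  # shape (N, M, 1)
```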
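To address the concern about solutions that assume a fixed number of dimensions, here is a minimal sketch that uses einsum's `...` ellipsis, so the same call works for any number of leading dimensions. It assumes a PyTorch version whose `torch.einsum` supports ellipsis (recent versions do); `batched_dot` is a hypothetical helper name, not a PyTorch function. The inputs are integer-valued floats so the results are exact, which makes the comparison against an elementwise-multiply-and-sum version and against explicit loops reliable.

```python
import torch

def batched_dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: dot product over the last dimension, kept as size 1.

    '...' matches all leading dimensions, so (N, M, D) -> (N, M, 1),
    (B, N, M, D) -> (B, N, M, 1), and so on.
    """
    # 'd' is in both inputs but not in the output, so einsum sums over it
    return torch.einsum('...d,...d->...', x, y).unsqueeze(-1)

N, M, D = 2, 3, 4
# Integer-valued floats keep every product and sum exact in float32
x = torch.arange(N * M * D, dtype=torch.float32).reshape(N, M, D)
y = x + 1

out = batched_dot(x, y)
print(out.shape)  # torch.Size([2, 3, 1])

# Equivalent without einsum: multiply elementwise, then sum the last axis
alt = (x * y).sum(dim=-1, keepdim=True)
print(torch.equal(out, alt))  # True

# Reference: explicit loops over the leading indices
ref = torch.empty(N, M, 1)
for n in range(N):
    for m in range(M):
        ref[n, m, 0] = torch.dot(x[n, m], y[n, m])
print(torch.equal(out, ref))  # True
```

The `(x * y).sum(dim=-1, keepdim=True)` form is arguably the most readable and also handles any number of leading dimensions, at the cost of materializing the full elementwise product `x * y` before the reduction.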

That actually does seem to work! Thanks :smiley: