Non-bmm batch matrix product operation

In exploring a word2vec variant, I’d like to find a way to do the following bmm-like operation, which bmm itself does not seem suited for. So far I can only write it with a for-loop. Is there a way to do this with PyTorch tensor operations?

import torch

def combine(ivectors, ovectors):
    # ivectors: (batch_sz, embedding_sz), ovectors: (batch_sz, sense_sz, context_sz, embedding_sz)
    (batch_sz, sense_sz, context_sz, embedding_sz) = ovectors.shape
    LL = torch.zeros((batch_sz, sense_sz, context_sz))
    for i in range(batch_sz):
        # (sense_sz, context_sz, embedding_sz) @ (embedding_sz,) -> (sense_sz, context_sz)
        LL[i, :, :] = torch.matmul(ovectors[i, :, :, :], ivectors[i, :])
    return LL

My problem seems to boil down to this: I want to take an input tensor of shape, say, 2x2x3x5 (ovectors), combine it with a tensor of shape 5x2 (ivectors.T), and get out a tensor of shape (2, 2, 3). What bmm and matmul seem to be designed for is a reduction that gives (2, 2, 3, 2), i.e. 24 elements, whereas the reduction in the for-loop above gives only 12.
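
For concreteness, a small sketch with those example sizes (random values, just to show the shapes the loop produces):

batch_sz, sense_sz, context_sz, embedding_sz = 2, 2, 3, 5
ovectors = torch.randn(batch_sz, sense_sz, context_sz, embedding_sz)  # 2x2x3x5
ivectors = torch.randn(batch_sz, embedding_sz)                        # 2x5, so ivectors.T is 5x2
LL = combine(ivectors, ovectors)
print(LL.shape)  # torch.Size([2, 2, 3]), i.e. 12 elements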

Is there a way to get rid of the for loop?

einsum can express reductions like this; perhaps what you want looks like:

torch.einsum("bsce,be->bsc", ovectors, ivectors)
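
As a quick sanity check (assuming ivectors has shape (batch_sz, embedding_sz), as in the loop above), this should reproduce the for-loop result:

via_einsum = torch.einsum("bsce,be->bsc", ovectors, ivectors)
print(via_einsum.shape)                                         # torch.Size([2, 2, 3])
print(torch.allclose(via_einsum, combine(ivectors, ovectors)))  # True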

Matmul can also do this reduction with shapes (2,2,3,5) @ (2,1,5,1) = (2,2,3,1) (this performs a bmm with the first two dimensions joined and broadcast):

torch.matmul(ovectors, ivectors.reshape(batch_sz,1,embedding_sz,1)).squeeze(-1)
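
And a sketch checking that the broadcasted matmul agrees with the einsum version on the same example inputs (batch_sz and embedding_sz as defined above):

via_matmul = torch.matmul(ovectors, ivectors.reshape(batch_sz, 1, embedding_sz, 1)).squeeze(-1)
print(via_matmul.shape)                        # torch.Size([2, 2, 3])
print(torch.allclose(via_matmul, via_einsum))  # True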

Thanks. That works and your explanation is very helpful!