How can I do element-wise batch matrix multiplication?

I solved it like this, though maybe I'm late to the party. `bmul` scales each slice of `mat` along `axis` by the matching entry of the 1-D tensor `vec`:

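```python
import torch

def bmul(vec, mat, axis=0):
    mat = mat.transpose(axis, -1)  # move `axis` last so `vec` lines up with it
    return (mat * vec.expand_as(mat)).transpose(axis, -1)  # scale, then restore layout
```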
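Here is a quick sanity check of `bmul`, plus a sketch of an equivalent that skips the two transposes. It reshapes `vec` so that PyTorch broadcasting aligns it with `axis` directly. The name `bmul_broadcast` and the example shapes are illustrative assumptions, not part of the original answer.

```python
import torch

def bmul(vec, mat, axis=0):
    # Move `axis` to the end, broadcast `vec` over it, then move it back.
    mat = mat.transpose(axis, -1)
    return (mat * vec.expand_as(mat)).transpose(axis, -1)

def bmul_broadcast(vec, mat, axis=0):
    # Hypothetical alternative: give `vec` size-1 dims everywhere except `axis`,
    # so ordinary broadcasting scales each slice of `mat` along `axis`.
    shape = [1] * mat.dim()
    shape[axis] = -1
    return mat * vec.view(shape)

mat = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
vec = torch.tensor([1.0, 10.0, 100.0])

out = bmul(vec, mat, axis=1)
# Slice i along dim 1 is scaled by vec[i].
print(torch.equal(out, bmul_broadcast(vec, mat, axis=1)))  # True
print(torch.equal(out[:, 2], mat[:, 2] * 100))  # True
```

Both versions rely on broadcasting. In `bmul`, the transposes only move `axis` into the trailing position, which is the dimension `expand_as` matches a 1-D tensor against.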