How do I do matrix multiplication (matmul) along a certain axis?

I am relatively new to PyTorch. After a pretty exhaustive search online, I still couldn't find the operation I want.

My question is: how do I do matrix multiplication (matmul) along a certain axis?

For example, if I want to multiply a vector by a matrix, that would just be the following:
a = torch.rand(3,5)
b = torch.rand(3)
torch.matmul(b,a)

One can interpret this as each element of b scaling the corresponding row of a, and then summing those scaled rows together.
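Here is a small check of that interpretation (the seed and sizes are just for illustration):

```python
import torch

torch.manual_seed(0)
a = torch.rand(3, 5)
b = torch.rand(3)

# matmul(b, a) treats b as a row vector: result[j] = sum_i b[i] * a[i, j]
mm = torch.matmul(b, a)

# the same thing written explicitly as a sum of scaled rows
scaled_rows = sum(b[i] * a[i] for i in range(3))

print(torch.allclose(mm, scaled_rows))  # True
```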

But what if a and b have the following dimensions:
a = torch.rand(3,5,10)
b = torch.rand(3,10)
and we want to do matrix multiplication along the last axis? Basically, here's what I want to do, in for-loop form:

product = []
for i in range(10):
    a_i = a[:, :, i]
    b_i = b[:, i]
    a_i_mul_b_i = torch.matmul(b_i, a_i)
    product.append(a_i_mul_b_i)

Here product is a list of tensors rather than a single tensor, but if we ignore the datatype, this is the end result I want.

Hello Zeyuyun!

The general-purpose tool for taking a product of (contracting) multiple
tensors along various axes is torch.einsum() (named after “Einstein
summation”).

(You can also fiddle with the dimensions to get them to line up as
needed and use matmul() or bmm().)

Here is a script that compares your loop code to einsum() (as well
as to bmm() and matmul()):

import torch
torch.__version__
torch.random.manual_seed (2020)

a = torch.rand(3,5,10)
b = torch.rand(3,10)

product = []
for i in range(10):
    a_i = a[:,:,i]
    b_i = b[:,i]
    a_i_mul_b_i = torch.matmul(b_i,a_i)
    product.append(a_i_mul_b_i)

# make product a tensor
product = torch.stack (product)

einsum_prod = torch.einsum ('ijk, ik -> kj', a, b)
print ('check einsum_prod ...\n', torch.eq (einsum_prod, product).all())

matmul_prod = torch.matmul (b.unsqueeze (1).transpose (0, 2), a.permute (2, 0, 1)).squeeze()
print ('check matmul_prod ...\n', torch.eq (matmul_prod, product).all())

bmm_prod = torch.bmm (b.unsqueeze (1).transpose (0, 2), a.permute (2, 0, 1)).squeeze()
print ('check bmm_prod ...\n', torch.eq (bmm_prod, product).all())

print ('product = ...\n', product)

Here is its output:

>>> import torch
>>> torch.__version__
'1.6.0'
>>> torch.random.manual_seed (2020)
<torch._C.Generator object at 0x7f8997a636f0>
>>>
>>> a = torch.rand(3,5,10)
>>> b = torch.rand(3,10)
>>>
>>> product = []
>>> for i in range(10):
...     a_i = a[:,:,i]
...     b_i = b[:,i]
...     a_i_mul_b_i = torch.matmul(b_i,a_i)
...     product.append(a_i_mul_b_i)
...
>>> # make product a tensor
>>> product = torch.stack (product)
>>>
>>> einsum_prod = torch.einsum ('ijk, ik -> kj', a, b)
>>> print ('check einsum_prod ...\n', torch.eq (einsum_prod, product).all())
check einsum_prod ...
 tensor(True)
>>>
>>> matmul_prod = torch.matmul (b.unsqueeze (1).transpose (0, 2), a.permute (2, 0, 1)).squeeze()
>>> print ('check matmul_prod ...\n', torch.eq (matmul_prod, product).all())
check matmul_prod ...
 tensor(True)
>>>
>>> bmm_prod = torch.bmm (b.unsqueeze (1).transpose (0, 2), a.permute (2, 0, 1)).squeeze()
>>> print ('check bmm_prod ...\n', torch.eq (bmm_prod, product).all())
check bmm_prod ...
 tensor(True)
>>>
>>> print ('product = ...\n', product)
product = ...
 tensor([[0.8785, 0.5736, 0.3109, 0.5423, 0.6269],
        [0.7494, 1.4107, 0.9018, 1.1483, 1.2441],
        [0.9368, 0.9044, 0.9225, 0.9634, 0.1672],
        [0.1231, 0.4003, 0.4123, 0.1015, 0.4601],
        [0.9209, 1.1959, 0.5452, 1.0301, 0.8842],
        [0.2887, 1.1165, 0.9888, 0.8110, 0.4526],
        [0.9565, 1.2474, 0.1234, 1.6425, 0.9914],
        [1.1996, 0.8205, 1.0448, 1.3298, 1.0197],
        [0.3437, 0.3698, 0.4044, 0.4140, 0.8048],
        [0.7642, 1.4165, 0.7622, 0.6675, 1.2127]])
>>>
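As an aside on how to read the subscript string 'ijk, ik -> kj': an index that appears in the inputs but not the output (here i) is summed over, while an index shared by both inputs and kept in the output (here k) acts as a batch dimension. A minimal sketch confirming this against an explicit loop (names are just for illustration):

```python
import torch

torch.manual_seed(2020)
a = torch.rand(3, 5, 10)
b = torch.rand(3, 10)

# i is contracted (summed); k is a shared batch index; j survives
einsum_prod = torch.einsum('ijk, ik -> kj', a, b)

# the equivalent explicit loop over the batch index k
loop_prod = torch.stack([torch.matmul(b[:, k], a[:, :, k]) for k in range(10)])

print(torch.allclose(einsum_prod, loop_prod))  # True
```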

Best.

K. Frank


Thank you so much!! That's exactly what I want. I think the key I was missing is "unsqueeze", but Einstein summation is another really good approach!