Hi,
I am having trouble with calculating a specific matrix product.
Let's say I have a single tensor A of size (N, m, m), and I want to calculate the matrix product of the N matrices of size (m, m) stored along its first dimension.
For example, if N=3, I want A[0,:,:] @ A[1,:,:] @ A[2,:,:]. Is there a way to do this without using a for loop?
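For reference, here is a minimal sketch of the product being asked about, first as an explicit loop and then with functools.reduce. The sizes N=3, m=4 and the random input are made up for illustration. Note that reduce only hides the loop: it still iterates in Python.

```python
import functools

import torch

# Hypothetical example sizes: N matrices, each of shape (m, m).
N, m = 3, 4
A = torch.randn(N, m, m)

# Explicit loop: multiply left to right, A[0] @ A[1] @ ... @ A[N-1].
result_loop = A[0]
for i in range(1, N):
    result_loop = result_loop @ A[i]

# Same left-to-right product via reduce; this still iterates in Python.
result_reduce = functools.reduce(torch.matmul, A.unbind(0))

print(torch.allclose(result_loop, result_reduce))
```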
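You could use torch.linalg.multi_dot (PyTorch 2.0 documentation). It takes a sequence of matrices rather than one stacked tensor, so split the tensor along its first dimension with .unbind(0):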
>>> import torch
>>> a = torch.stack([torch.eye(3) * 2, torch.eye(3) * 3, torch.eye(3) * 5])
>>> a
tensor([[[2., 0., 0.],
         [0., 2., 0.],
         [0., 0., 2.]],

        [[3., 0., 0.],
         [0., 3., 0.],
         [0., 0., 3.]],

        [[5., 0., 0.],
         [0., 5., 0.],
         [0., 0., 5.]]])
>>> a.shape
torch.Size([3, 3, 3])
>>> torch.linalg.multi_dot(a.unbind(0))
tensor([[30.,  0.,  0.],
        [ 0., 30.,  0.],
        [ 0.,  0., 30.]])
>>>
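The diagonal matrices above commute, so that example would give 30·I in any order. A quick sketch with random (non-commuting) matrices checks that multi_dot really computes a[0] @ a[1] @ a[2] in that order. The seed and float64 dtype are arbitrary choices for a reproducible comparison.

```python
import torch

# Random matrices do not commute, so this checks that multi_dot keeps the
# order a[0] @ a[1] @ a[2].
torch.manual_seed(0)
a = torch.randn(3, 3, 3, dtype=torch.float64)

expected = a[0] @ a[1] @ a[2]
result = torch.linalg.multi_dot(a.unbind(0))  # unbind(0) -> tuple of three (3, 3) tensors

print(torch.allclose(result, expected))
# The reversed order generally gives a different matrix.
print(torch.allclose(result, a[2] @ a[1] @ a[0]))
```

multi_dot may choose a different parenthesization to reduce cost, but it never reorders the factors, so results match the sequential product up to floating-point rounding.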
Thanks a lot! It seems what I was missing was the .unbind() call.
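.unbind(0) turns the (N, m, m) tensor into a tuple of N tensors of shape (m, m), which is the sequence form multi_dot expects (its documentation describes the input as two or more tensors, so N=1 needs a special case). If you want to push more of the work into batched operations, one alternative (a different technique from multi_dot, sketched here as a hypothetical helper `chain_matmul_pairwise`) is a pairwise tree reduction: multiply neighbouring pairs with a single batched matmul per round, so the Python loop runs about log2(N) times instead of N - 1.

```python
import torch


def chain_matmul_pairwise(A: torch.Tensor) -> torch.Tensor:
    """Product A[0] @ A[1] @ ... @ A[N-1] for A of shape (N, m, m).

    Hypothetical helper: each round multiplies pairs (0, 1), (2, 3), ...
    with one batched matmul, halving the number of matrices.
    """
    if A.dim() != 3 or A.shape[-1] != A.shape[-2]:
        raise ValueError("expected a tensor of shape (N, m, m)")
    while A.shape[0] > 1:
        n = A.shape[0]
        # Even-indexed matrices on the left, odd-indexed on the right,
        # so each pair keeps its left-to-right order.
        paired = A[0:n - 1:2] @ A[1:n:2]
        if n % 2 == 1:
            # Odd count: the last matrix has no partner; carry it to the next round.
            paired = torch.cat([paired, A[-1:]], dim=0)
        A = paired
    return A[0]


torch.manual_seed(0)
A = torch.randn(5, 4, 4, dtype=torch.float64)
parts = A.unbind(0)  # tuple of 5 tensors, each of shape (4, 4)
print(len(parts), parts[0].shape)
print(torch.allclose(chain_matmul_pairwise(A), torch.linalg.multi_dot(parts)))
```

Unlike multi_dot, this helper also handles N=1 by returning A[0] directly.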