I have multiple tensors mixing batched matrices and broadcasted matrices. Is there an efficient and elegant way to compute their matrix product?
The snippet below is how it's currently implemented, which is ugly. By the way, a for loop is best avoided for performance reasons.
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)
result = torch.matmul(
    torch.matmul(
        torch.matmul(
            torch.matmul(tensor1, tensor2),
            tensor3,
        ),
        tensor4,
    ),
    tensor5,
)
print(result.size()) # torch.Size([10, 3, 8])
Hi @MrCrHaM,
You can get the desired output by using torch.func.vmap and torch.linalg.multi_dot; here's a minimal reproducible example.
import torch
from torch.linalg import multi_dot
from torch.func import vmap
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)
def func(A, B, C, D, E):
    # multi_dot takes a single sequence of matrices, hence the inner parentheses
    return multi_dot((A, B, C, D, E))

out = vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5)
print(out.shape)  # returns torch.Size([10, 3, 8])
Hi @AlphaBetaGamma96,
Thanks a lot for the great answer.
However, I'm currently using PyTorch 1.12.1. It seems that torch.func is not available there, and upgrading to 2.0 would involve upgrading CUDA and the NVIDIA driver accordingly, which is not an option on the cluster for now.
Any ideas for solving this without any PyTorch 2.0 specific features?
You can change torch.func.vmap to functorch.vmap; the functorch package should have been installed alongside pytorch when you initially installed it.
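For reference, here's a minimal sketch of the same solution with the pre-2.0 namespace (assuming functorch is importable in your environment):
import torch
from torch.linalg import multi_dot
from functorch import vmap  # pre-2.0 home of torch.func.vmap

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

def func(A, B, C, D, E):
    return multi_dot((A, B, C, D, E))

# map over dim 0 of the batched tensors, broadcast the unbatched ones
out = vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5)
print(out.shape)  # torch.Size([10, 3, 8])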
Why would this be the case? The PyTorch 2.0.0 binaries ship with CUDA 11.7 and 11.8 and all needed dependencies. You would thus only need to install a proper NVIDIA driver, which supports minor-version compatibility since CUDA 11. Were you using a CUDA 10.2 build?
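You can check which CUDA version your current binaries were built against directly from Python, e.g.:
import torch
print(torch.__version__)   # e.g. 1.12.1
print(torch.version.cuda)  # CUDA version the binaries ship with, e.g. 10.2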
Sadly, it seems the functorch package is not available for PyTorch 1.12.1 either.
>>> import functorch
ModuleNotFoundError: No module named 'functorch'
The cluster I'm running code on uses CUDA 10.2, and according to the administrator, upgrading CUDA is not planned for now, unfortunately. So I guess PyTorch 2.0 specific features won't be available to me anytime soon.
Ah OK, that explains it. Moving between major CUDA versions needs a driver update (unless you are using the compat driver package on data center GPUs). Note that PyTorch does not support CUDA 10.2 anymore, so you would eventually need to update if you want to use any newer release.
Hi @MrCrHaM,
I read the functorch docs; for your case, you can install functorch via pip install functorch. More info in the docs here: Install functorch — functorch nightly documentation
You should be able to install it with conda too if you need to, by installing it locally with pip inside your conda environment.
Also, you will get a warning that you're using a deprecated package and should use torch.func instead; just make sure to suppress the warnings with the warnings package (part of the standard library) and things should be fine.
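For example, something along these lines should silence the deprecation message (a sketch; the exact warning category functorch emits may differ):
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # drop the deprecation message
    from functorch import vmap
    out = vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5)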
Thanks for the info. I'll pass this along to our cluster administrator.
Thanks for the further instructions. The functorch module works like a charm.