# Chained broadcasting matmul without for loop

I have multiple tensors, a mix of batched matrices and broadcasted (unbatched) matrices. Is there an efficient and elegant way to compute their chained matrix product?

The following snippet shows my current implementation, which is ugly. Note that a Python for loop is best avoided for performance reasons.

```python
import torch

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

result = torch.matmul(
    torch.matmul(
        torch.matmul(
            torch.matmul(tensor1, tensor2),
            tensor3,
        ),
        tensor4,
    ),
    tensor5,
)
print(result.size())  # torch.Size([10, 3, 8])
```

Hi @MrCrHaM,

You can get the desired output using `torch.func.vmap` and `torch.linalg.multi_dot`; here’s a minimal reproducible example.

```python
import torch
from torch.linalg import multi_dot
from torch.func import vmap

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

def func(A, B, C, D, E):
    # note the inner parentheses: multi_dot takes a single sequence of matrices
    return multi_dot((A, B, C, D, E))

out = vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5)
print(out.shape)  # torch.Size([10, 3, 8])
```
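As a quick sanity check (not part of the original answer), the `vmap`/`multi_dot` result can be compared against the original chained `torch.matmul` implementation; the `@` operator broadcasts the unbatched matrices exactly like `torch.matmul` does:

```python
import torch
from torch.func import vmap
from torch.linalg import multi_dot

torch.manual_seed(0)
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

def func(A, B, C, D, E):
    return multi_dot((A, B, C, D, E))

out = vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5)

# reference: the original chained implementation (matmul broadcasts the 2-D tensors)
ref = tensor1 @ tensor2 @ tensor3 @ tensor4 @ tensor5

# multi_dot may multiply in a different order, so allow a small float32 tolerance
print(torch.allclose(out, ref, atol=1e-5))  # True
```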

Thanks a lot for the great answer.

However, I’m currently using PyTorch 1.12.1. It seems that `torch.func` is not available, and upgrading to 2.0 would involve upgrading CUDA and the NVIDIA driver accordingly, which is not an option for the cluster right now.

Any ideas to solve it without any PyTorch 2.0 specific features?
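One version-agnostic alternative worth noting (a sketch, not from the thread): since `torch.matmul` broadcasts on its own, `functools.reduce` can chain the products exactly as the original nested calls do, with no `torch.func` or `functorch` dependency. Unlike `multi_dot`, it will not pick an optimal multiplication order.

```python
import functools
import torch

tensors = [
    torch.randn(10, 3, 4),
    torch.randn(4, 5),
    torch.randn(10, 5, 6),
    torch.randn(6, 7),
    torch.randn(10, 7, 8),
]

# reduce chains torch.matmul left-to-right, reproducing the nested calls
# from the original snippet without a per-batch-element Python loop
result = functools.reduce(torch.matmul, tensors)
print(result.shape)  # torch.Size([10, 3, 8])
```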

You can change `torch.func.vmap` to `functorch.vmap`, the `functorch` package should have been installed alongside `pytorch` when you initially installed it.
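For reference, here is the earlier example rewritten for the older API, assuming `functorch` is importable in your environment; the call signature is the same:

```python
import torch
from functorch import vmap  # pre-2.0 location of torch.func.vmap
from torch.linalg import multi_dot

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

def func(A, B, C, D, E):
    return multi_dot((A, B, C, D, E))

out = vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5)
print(out.shape)  # torch.Size([10, 3, 8])
```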

Why would this be the case? The PyTorch 2.0.0 binaries ship with CUDA 11.7 and 11.8 and all needed dependencies. You would thus only need to install a proper NVIDIA driver, which supports minor-version compatibility since CUDA 11. Were you using a CUDA 10.2 build?

Sadly, it seems the `functorch` package is not available in my PyTorch 1.12.1 install either.

```python
>>> import functorch
ModuleNotFoundError: No module named 'functorch'
```

The cluster I’m running code on is using CUDA 10.2 and according to the administrator upgrading CUDA is not planned for now unfortunately. So I guess PyTorch 2.0 specific features won’t be available for me anytime soon.

Ah OK, that explains it. Major CUDA versions need a driver update (unless you are using the compat driver package on data center GPUs). Note that PyTorch does not support CUDA 10.2 anymore, so you would eventually need to update if you want to use any newer release.

Hi @MrCrHaM,

I read the `functorch` docs; for your case, you can install `functorch` via `pip install functorch`. More info in the docs here: Install functorch — functorch nightly documentation

You should be able to install it in a conda setup too if you need to, by running pip locally inside your conda environment.

Also, you will be warned that you’re using a deprecated package and should use `torch.func` instead; just suppress the warnings with the standard-library `warnings` module and things should be fine.
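The suppression mentioned above can be sketched with the standard-library `warnings` module; the warning text below is a stand-in for illustration, not the exact message `functorch` emits:

```python
import warnings

def noisy_call():
    # stand-in for importing/calling the deprecated functorch package,
    # which warns that torch.func should be used instead
    warnings.warn("functorch is deprecated; use torch.func", DeprecationWarning)
    return "ok"

# suppress only within this scope so unrelated warnings stay visible
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    status = noisy_call()

print(status)  # ok
```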

Thanks for the info. I’ll send these words to our cluster administrator.

Thanks for the further instructions. The `functorch` module works like a charm.