Chained broadcasting matmul without for loop

I have multiple tensors, a mix of batched matrices and matrices that should be broadcast across the batch. Is there an efficient and elegant way to get their chained matrix product?

The following snippet shows how it’s implemented currently, which is ugly. BTW, a for loop is better avoided for performance reasons.

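import torch

tensor1 = torch.randn(10, 3, 4)  # batched
tensor2 = torch.randn(4, 5)      # shared across the batch
tensor3 = torch.randn(10, 5, 6)  # batched
tensor4 = torch.randn(6, 7)      # shared across the batch
tensor5 = torch.randn(10, 7, 8)  # batched

# Nested calls; torch.matmul broadcasts the 2-D matrices over the batch dim
result = torch.matmul(torch.matmul(torch.matmul(torch.matmul(
    tensor1, tensor2), tensor3), tensor4), tensor5)
print(result.size())  # torch.Size([10, 3, 8])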
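Since torch.matmul already broadcasts a 2-D matrix against a batch of matrices, the chain needs no loop over the batch dimension; the nesting is only a readability problem. A minimal sketch, assuming nothing beyond plain PyTorch (no torch.func or functorch), writes the same left-to-right chain with functools.reduce from the standard library. It still issues one matmul per factor, exactly like the nested calls.

```python
import functools

import torch

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

# Left fold: ((((t1 @ t2) @ t3) @ t4) @ t5). torch.matmul broadcasts the
# 2-D matrices (tensor2, tensor4) across the batch dimension of size 10.
tensors = [tensor1, tensor2, tensor3, tensor4, tensor5]
result = functools.reduce(torch.matmul, tensors)
print(result.size())  # torch.Size([10, 3, 8])
```

Adding or removing factors only changes the list, not the nesting.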

Hi @MrCrHaM,

You can get the desired output by using torch.func.vmap and torch.linalg.multi_dot. Here’s a minimal reproducible example:

import torch
from torch.linalg import multi_dot
from torch.func import vmap

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

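def func(A, B, C, D, E):
    # multi_dot takes a single sequence of tensors, hence the inner parentheses
    return multi_dot((A, B, C, D, E))

# in_dims: 0 = map over the leading batch dim, None = reuse the same 2-D matrix
out = vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5)
print(out.shape)  # torch.Size([10, 3, 8])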
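To make in_dims concrete: conceptually, vmap computes what the explicit loop below computes (slice each argument with in_dim 0 at batch index b, pass each None argument unchanged, then stack the results), except that it vectorizes the work instead of iterating in Python. This sketch uses the loop only as a reference; it also compares against the plain broadcasting matmul chain. Because multi_dot may choose a different multiplication order, the comparison uses a float32 tolerance (the rtol/atol values are an assumption, not a documented bound).

```python
import torch
from torch.linalg import multi_dot

torch.manual_seed(0)  # reproducible random inputs
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)

# Loop reference for in_dims=(0, None, 0, None, 0): batched tensors
# (in_dim 0) are sliced at index b, shared ones (None) are passed unchanged.
ref = torch.stack([
    multi_dot((tensor1[b], tensor2, tensor3[b], tensor4, tensor5[b]))
    for b in range(tensor1.shape[0])
])
print(ref.shape)  # torch.Size([10, 3, 8])

# Plain broadcasting matmul, evaluated left to right.
chained = tensor1 @ tensor2 @ tensor3 @ tensor4 @ tensor5
print(torch.allclose(ref, chained, rtol=1e-4, atol=1e-3))
```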

Hi @AlphaBetaGamma96,

Thanks a lot for the great answer.

However, I’m currently using PyTorch 1.12.1. It seems that torch.func is not available, and upgrading to 2.0 would involve upgrading CUDA and the NVIDIA driver accordingly, which is not an option on the cluster for now.

Any ideas to solve it without any PyTorch 2.0 specific features?

You can change torch.func.vmap to functorch.vmap; the functorch package should have been installed alongside PyTorch when you initially installed it.
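A minimal sketch of a version-tolerant import, assuming either PyTorch >= 2.0 (torch.func.vmap) or an older release with the standalone functorch package installed (functorch.vmap). The rest of the example is unchanged, since both vmap functions accept the same in_dims argument.

```python
import torch
from torch.linalg import multi_dot

try:
    # PyTorch >= 2.0: vmap lives in torch.func
    from torch.func import vmap
except ImportError:
    # Older PyTorch (e.g. 1.12.1): standalone functorch package
    from functorch import vmap

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
tensor3 = torch.randn(10, 5, 6)
tensor4 = torch.randn(6, 7)
tensor5 = torch.randn(10, 7, 8)


def func(A, B, C, D, E):
    # multi_dot takes one sequence of 2-D tensors
    return multi_dot((A, B, C, D, E))


# 0 = map over dim 0 of that argument, None = same tensor for every batch item
out = vmap(func, in_dims=(0, None, 0, None, 0))(
    tensor1, tensor2, tensor3, tensor4, tensor5)
print(out.shape)  # torch.Size([10, 3, 8])
```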

Why would this be the case? The PyTorch 2.0.0 binaries ship with CUDA 11.7 and 11.8 and all needed dependencies. You would thus only need to install a proper NVIDIA driver, which supports minor-version compatibility since CUDA 11. Were you using a CUDA 10.2 build?

Sadly, it seems the functorch package is not available for PyTorch 1.12.1 either.

>>> import functorch
ModuleNotFoundError: No module named 'functorch'

The cluster I’m running code on uses CUDA 10.2, and according to the administrator, upgrading CUDA is unfortunately not planned for now. So I guess PyTorch 2.0 specific features won’t be available to me anytime soon. :face_exhaling:

Ah OK, that explains it. Major CUDA versions need a driver update (unless you are using the compat driver package on data center GPUs). Note that PyTorch does not support CUDA 10.2 anymore, so you would eventually need to update if you want to use any newer release.

Hi @MrCrHaM,

I read the functorch docs. For your case, you can install functorch via pip install functorch; more info is on the “Install functorch” page of the functorch nightly documentation.

You should also be able to install it with conda if you need to, by running pip inside your conda environment.

Also, you will get a warning that you’re using a deprecated package and should use torch.func instead. Just suppress the warnings with the warnings module (part of the Python standard library) and things should be fine.
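A minimal sketch of scoped suppression with the standard-library warnings module. call_quietly and noisy_square are hypothetical names for illustration; noisy_square only stands in for a deprecated call, since the exact functorch warning text is not quoted in this thread. warnings.catch_warnings() restores the previous filters on exit, so only warnings raised inside the wrapped call are silenced.

```python
import warnings


def call_quietly(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) with every warning silenced, for this call only."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return fn(*args, **kwargs)


def noisy_square(x):
    # Hypothetical stand-in for a call that emits a deprecation warning.
    warnings.warn("this API is deprecated", DeprecationWarning)
    return x * x


print(call_quietly(noisy_square, 4))  # 16, with the warning suppressed
```

To cover both building and calling the vmapped function, wrap the whole expression, e.g. `out = call_quietly(lambda: vmap(func, in_dims=(0, None, 0, None, 0))(tensor1, tensor2, tensor3, tensor4, tensor5))`. Silencing everything is a blunt choice; warnings.filterwarnings with a specific category or message is narrower if you know what the warning looks like.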

Thanks for the info. I’ll send these words to our cluster administrator. :joy:

Thanks for the further instruction. Module functorch works like a charm.