I am trying to multiply several matrices together in PyTorch and was wondering what the PyTorch equivalent of numpy.linalg.multi_dot() is.

If there isn’t one, what is the next best way (in terms of speed and memory) to do this in PyTorch?

Code:

import numpy as np
import torch
A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)
results = np.linalg.multi_dot([A, B, C])  # multi_dot takes a single sequence of arrays
A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)
# What is the PyTorch equivalent of np.linalg.multi_dot()?
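For the 3×3 example above, a minimal sketch that should work on any PyTorch version is to chain `torch.matmul` (the `@` operator) and check the result against NumPy. Keep in mind that `@` simply evaluates left to right, so unlike `multi_dot` it does not pick the cheapest parenthesization; that only matters when the matrix shapes differ.

```python
import numpy as np
import torch

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)

# NumPy reference: multi_dot takes one list of arrays
expected = np.linalg.multi_dot([A, B, C])

# torch.from_numpy shares memory with the arrays and keeps dtype float64
A_tsr, B_tsr, C_tsr = (torch.from_numpy(m) for m in (A, B, C))

# @ is torch.matmul, evaluated left to right as (A @ B) @ C
result = A_tsr @ B_tsr @ C_tsr

print(np.allclose(result.numpy(), expected))  # True
```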

I don’t think PyTorch has an equivalent function (though I’m not sure), but I just looked at NumPy’s source code and the implementation is fairly simple: every function it relies on is also available in PyTorch, so you could port the same method. It shouldn’t take much time, since in most places you only need to swap the numpy calls for their torch counterparts, plus maybe a few tricks.

That’s really nice. Maybe you could open a PR on the official PyTorch GitHub repo so other people can use it too.
AFAIK, PyTorch is trying to replicate all the methods in np.linalg, so maybe you can add this to that list and take responsibility for it: Linear Algebra tracking issue · Issue #42666 · pytorch/pytorch (github.com)

You can find multi_dot in the list of todos (planned for PyTorch 1.9), and apparently it has not been assigned to anyone yet.
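For anyone reading this later: as far as I know this item did land, and PyTorch 1.9 and later ship `torch.linalg.multi_dot`, which takes a list of tensors just like NumPy’s version (older releases had `torch.chain_matmul`, which is now deprecated). A minimal sketch, assuming PyTorch >= 1.9:

```python
import torch

A = torch.rand(3, 3, dtype=torch.float64)
B = torch.rand(3, 3, dtype=torch.float64)
C = torch.rand(3, 3, dtype=torch.float64)

# Like np.linalg.multi_dot, this takes one list of tensors and reorders the
# multiplications so that the fewest arithmetic operations are performed
result = torch.linalg.multi_dot([A, B, C])

print(torch.allclose(result, A @ B @ C))  # True
```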