I want to compute X @ W_k @ Y for a list of W_k, k=1,…,K, where K is around 10 or so (but is a parameter). How might I parallelize this? My best guess is to combine them all into a block-diagonal matrix (e.g. with torch.block_diag), but I imagine there is a more efficient way to do this.
I think you can use torch.einsum for that:
import torch

k = 10          # number of W_k matrices
a, b, c, d = 2, 3, 4, 5

X = torch.randn(a, b)
Ws = [torch.randn(b, c) for _ in range(k)]
W = torch.stack(Ws, dim=0)  # shape: [k, b, c]
Y = torch.randn(c, d)

result = torch.einsum("ab,kbc,cd->kad", X, W, Y)  # shape: [k, a, d]
Note that since this einsum involves more than two tensors, you may get better performance if you install opt-einsum (nothing else to do: torch.einsum will detect that it is installed and use it to optimize the contraction order). pip install opt-einsum should be enough for that.
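As a quick sanity check that the single einsum really computes X @ W_k @ Y for every k, here is a sketch comparing it against the explicit loop it replaces (names match the snippet above; sizes are illustrative):

```python
import torch

k, a, b, c, d = 10, 2, 3, 4, 5
X = torch.randn(a, b)
W = torch.randn(k, b, c)  # stacked W_k, shape [k, b, c]
Y = torch.randn(c, d)

# the einsum contraction from above
result = torch.einsum("ab,kbc,cd->kad", X, W, Y)

# the explicit loop it replaces: X @ W_k @ Y for each k
looped = torch.stack([X @ W[i] @ Y for i in range(k)], dim=0)

assert torch.allclose(result, looped, atol=1e-5)
```

In recent PyTorch versions you can also check whether the opt-einsum backend is active via torch.backends.opt_einsum.is_available().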
Thanks, that did the trick! I also had a batch dimension to deal with, but “nab,kbc,ncd->nkad” handled that.
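For anyone landing here later, a sketch of that batched variant (sizes are illustrative; n is the batch dimension, shared by X and Y, while k indexes the W_k):

```python
import torch

n, k, a, b, c, d = 6, 10, 2, 3, 4, 5
X = torch.randn(n, a, b)  # batch of X matrices
W = torch.randn(k, b, c)  # stacked W_k
Y = torch.randn(n, c, d)  # batch of Y matrices

# for each batch element i and each k: X_i @ W_k @ Y_i
result = torch.einsum("nab,kbc,ncd->nkad", X, W, Y)  # shape: [n, k, a, d]

# sanity check against the explicit double loop
looped = torch.stack(
    [torch.stack([X[i] @ W[j] @ Y[i] for j in range(k)]) for i in range(n)]
)
assert torch.allclose(result, looped, atol=1e-5)
```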