I know that PyTorch can handle batch matrix multiplication, like (B, X, Y) * (B, Y, Z) → (B, X, Z).
But what if the matrices share two batch dimensions? I want (B1, B2, X, Y) * (B1, B2, Y, Z) → (B1, B2, X, Z), and that doesn’t seem to be handled by the default batch multiplication. Is there a clean way to do this?
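For reference, here is a minimal sketch. `torch.matmul` (the `@` operator) treats every dimension except the last two as a batch dimension, so it should cover the two-batch-dimension case directly. `torch.einsum` expresses the same contraction explicitly. The concrete sizes below are arbitrary values chosen only for illustration.

```python
import torch

# Arbitrary example sizes (assumed for illustration only)
B1, B2, X, Y, Z = 2, 3, 4, 5, 6

a = torch.randn(B1, B2, X, Y)
b = torch.randn(B1, B2, Y, Z)

# torch.matmul / @ multiplies the last two dims and treats
# all leading dims (here B1 and B2) as batch dims.
out = a @ b
print(out.shape)  # torch.Size([2, 3, 4, 6])

# Same contraction with einsum: sum over y for each (b1, b2, x, z).
out_einsum = torch.einsum("abxy,abyz->abxz", a, b)
print(torch.allclose(out, out_einsum, atol=1e-5))
```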
Yes, that’s actually what I’ve been doing, but the code got messy. I was just wondering whether there’s a single function that takes care of this, since there seem to be about a million matrix-manipulation functions here…
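Assuming the workaround being discussed is to flatten the two batch dimensions into one, call `torch.bmm`, and reshape back, the sketch below compares that approach with a single `@` call. The helper name `matmul_via_bmm` is hypothetical and exists only for this comparison.

```python
import torch


def matmul_via_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten (B1, B2) into one batch dim, bmm, restore."""
    b1, b2, x, y = a.shape
    z = b.shape[-1]
    # torch.bmm only accepts 3-D tensors, hence the reshapes.
    out = torch.bmm(a.reshape(b1 * b2, x, y), b.reshape(b1 * b2, y, z))
    return out.reshape(b1, b2, x, z)


a = torch.randn(2, 3, 4, 5)
b = torch.randn(2, 3, 5, 6)

manual = matmul_via_bmm(a, b)
direct = a @ b  # one call; torch.matmul handles the leading batch dims

print(torch.allclose(manual, direct, atol=1e-5))
```

Besides being shorter, `torch.matmul` also broadcasts batch dimensions, so a `b` of shape (Y, Z) or (B1, 1, Y, Z) works without manual expansion, while the `bmm` route needs both inputs to have identical batch shapes.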