Matrix multiplication with two common dimensions

I know that PyTorch can handle batched matrix multiplication, like (B, X, Y) * (B, Y, Z) → (B, X, Z).
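For reference, here is a minimal sketch of that 3-D case using `torch.bmm`, which performs one (X, Y) @ (Y, Z) product per batch index. The sizes are arbitrary placeholders.

```python
import torch

torch.manual_seed(0)

# Arbitrary example sizes: batch B, then an (X, Y) @ (Y, Z) product per batch entry.
B, X, Y, Z = 4, 2, 3, 5
a = torch.randn(B, X, Y)
b = torch.randn(B, Y, Z)

# torch.bmm requires exactly 3-D inputs with matching batch size.
out = torch.bmm(a, b)
print(out.shape)  # torch.Size([4, 2, 5])
```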

But what if the matrices have two common batch dimensions? I want (B1, B2, X, Y) * (B1, B2, Y, Z) → (B1, B2, X, Z), and this doesn’t seem to be handled by the default batched multiplication. Is there a good way to do this?

I might be misunderstanding your use case, but couldn’t you flatten the first two dimensions into one and unflatten them afterwards?
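A minimal sketch of that flatten/unflatten approach, assuming both inputs share the same (B1, B2) leading shape. The helper name `batched_matmul_4d` is hypothetical, chosen only for illustration; it merges the two batch dimensions with `reshape` so that `torch.bmm` accepts the tensors, then splits the result back.

```python
import torch


def batched_matmul_4d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B1, B2, X, Y) x (B1, B2, Y, Z) -> (B1, B2, X, Z)."""
    b1, b2, x, y = a.shape
    z = b.shape[-1]
    # Merge the two leading batch dims into one so torch.bmm gets 3-D input.
    # reshape (unlike view) also works if the input is non-contiguous.
    a3 = a.reshape(b1 * b2, x, y)
    b3 = b.reshape(b1 * b2, y, z)
    out = torch.bmm(a3, b3)  # shape (B1 * B2, X, Z)
    # Split the merged batch dim back into (B1, B2).
    return out.reshape(b1, b2, x, z)


torch.manual_seed(0)
a = torch.randn(2, 3, 4, 5)
b = torch.randn(2, 3, 5, 6)
res = batched_matmul_4d(a, b)
print(res.shape)  # torch.Size([2, 3, 4, 6])
```

Because `reshape` flattens in row-major order, batch entry (i, j) lands at index i * B2 + j and is restored to the same position on the way back.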

Yes, that’s actually what I’ve been doing. But the code became messy, and I was just wondering whether there’s a single function that takes care of this, since there seem to be about a million matrix-manipulation functions here… :wink:
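For comparison, a minimal sketch of a single-call version. `torch.matmul` (and the `@` operator) treat every dimension except the last two as a batch dimension and broadcast over them, so the 4-D case should not need any manual reshaping. `torch.einsum` is shown as an alternative that spells out the contraction explicitly; the sizes are arbitrary placeholders.

```python
import torch

torch.manual_seed(0)

B1, B2, X, Y, Z = 2, 3, 4, 5, 6
a = torch.randn(B1, B2, X, Y)
b = torch.randn(B1, B2, Y, Z)

# torch.matmul batches over all leading dims: here (B1, B2).
out = torch.matmul(a, b)  # equivalently: a @ b
print(out.shape)  # torch.Size([2, 3, 4, 6])

# Same contraction with einsum: sum over the shared index y.
out_einsum = torch.einsum("pqxy,pqyz->pqxz", a, b)
print(torch.allclose(out, out_einsum, atol=1e-5))
```

The einsum form can be handy when the batch dimensions are not the leading ones, since the index string states which axes pair up.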