Batch matrix multiplication of 3D tensors

Hi Shon!

It’s not entirely clear to me what the values num_cats and k are.

Let me assume that you want the shape of your final “all the multiplication
together” result to be [batch_size, num_cats, k, k].

You can compute your three-term product with einsum(). For each batch
element and category, this computes X @ C @ Y^T, summing over f and g:

result = torch.einsum('bnkf, bnfg, bnlg -> bnkl', all_X, all_C, all_Y)

Make sure you understand how einsum() works so that you are
contracting (“multiplying” together) the desired pairs of indices.
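One way to check that einsum() is contracting the indices you intend is to compare it against an explicit chain of batched matmuls. This is a minimal sketch with made-up sizes (batch_size, num_cats, k and ffnn are placeholders) and random tensors. It assumes all_Y has the same k as all_X, so the result comes out as [batch_size, num_cats, k, k].

```python
import torch

torch.manual_seed(0)

# Placeholder sizes; in your code these come from your actual data.
batch_size, num_cats, k, ffnn = 2, 3, 4, 5

all_X = torch.randn(batch_size, num_cats, k, ffnn)     # b n k f
all_C = torch.randn(batch_size, num_cats, ffnn, ffnn)  # b n f g
all_Y = torch.randn(batch_size, num_cats, k, ffnn)     # b n l g

# f and g do not appear after "->", so they are summed over.
result = torch.einsum('bnkf, bnfg, bnlg -> bnkl', all_X, all_C, all_Y)

# The same contraction as ordinary batched matmuls: X @ C @ Y^T,
# applied independently to every (batch, category) pair.
expected = all_X @ all_C @ all_Y.transpose(-2, -1)

print(result.shape)  # torch.Size([2, 3, 4, 4])
print(torch.allclose(result, expected, atol=1e-4))
```

If the two disagree for your real tensors, the index letters in the einsum() string are pairing up different dimensions than you intended.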

If all_C is constructed from the weights of Linears, you may need
to transpose the last two dimensions (the ffnn dimensions) of all_C,
either explicitly or by swapping the symbolic indices in the einsum()
expression, in order to replicate the result you would get by applying
the Linear to all_Y.
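nn.Linear stores its weight as [out_features, in_features] and computes Y @ W^T, so what matters is which index of all_C you contract all_Y against. The sketch below assumes, hypothetically, that all_C is built by stacking the weights of one bias-free Linear per category and broadcasting over the batch. It then checks that contracting all_Y's feature index with the last (in_features) index of all_C, or equivalently multiplying by all_C transposed, reproduces applying the Linear.

```python
import torch

torch.manual_seed(0)

# Placeholder sizes; "ffnn" is used for both in_features and out_features.
batch_size, num_cats, k, ffnn = 2, 3, 4, 5

# Hypothetical construction: one bias-free Linear per category.
linears = torch.nn.ModuleList(
    [torch.nn.Linear(ffnn, ffnn, bias=False) for _ in range(num_cats)]
)
all_Y = torch.randn(batch_size, num_cats, k, ffnn)

with torch.no_grad():
    # Each weight is [out_features, in_features]; stack to [num_cats, out, in],
    # then broadcast over the batch to [batch_size, num_cats, out, in].
    W = torch.stack([lin.weight for lin in linears])
    all_C = W.unsqueeze(0).expand(batch_size, -1, -1, -1)

    # Reference: apply each category's Linear to its slice of all_Y.
    ref = torch.stack([linears[n](all_Y[:, n]) for n in range(num_cats)], dim=1)

    # Contract all_Y's feature index g with the in_features (last) index of C ...
    via_einsum = torch.einsum('bnlg,bnfg->bnlf', all_Y, all_C)
    # ... which is the same as multiplying by C with its last two dims swapped.
    via_matmul = all_Y @ all_C.transpose(-2, -1)

print(torch.allclose(ref, via_einsum, atol=1e-5))
print(torch.allclose(ref, via_matmul, atol=1e-5))
```

Contracting all_Y against the first (out_features) index of all_C instead would correspond to using the untransposed weight, which is not what the Linear computes.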

(This again ignores any bias terms that may have appeared in your
Linears.)
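To see what ignoring the bias means for a single Linear, here is a small sketch with placeholder sizes: a weight-only einsum() contraction differs from calling the Linear, and the difference is exactly the bias broadcast-added to every row.

```python
import torch

torch.manual_seed(0)

k, ffnn = 4, 5                       # placeholder sizes
lin = torch.nn.Linear(ffnn, ffnn)    # bias=True by default
Y = torch.randn(k, ffnn)

with torch.no_grad():
    full = lin(Y)                                       # Y @ W^T + b
    no_bias = torch.einsum('lg,fg->lf', Y, lin.weight)  # Y @ W^T only
    with_bias = no_bias + lin.bias                      # bias added to every row

print(torch.allclose(full, no_bias))
print(torch.allclose(full, with_bias, atol=1e-5))
```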

Best.

K. Frank