Hi Shon!
It’s not entirely clear to me what the values num_cats and k are.
Let me assume that you want the shape of your final “all the multiplication
together” result to be [batch_size, num_cats, k, k].
You can compute your three-term product with einsum():
result = torch.einsum ('bnkf, bnfg, bnlg -> bnkl', all_X, all_C, all_Y)
Make sure you understand how einsum() works so that you are
contracting (“multiplying” together) the desired pairs of indices.
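One way to check which indices are being contracted is to compare the einsum() result against explicit matrix multiplications. The following is only a sketch: the sizes are made up for illustration, and it assumes all_X and all_Y have shape [batch_size, num_cats, k, ffnn] and all_C has shape [batch_size, num_cats, ffnn, ffnn], which may not match your actual tensors.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, num_cats, k, ffnn = 2, 3, 4, 5

# Assumed shapes (not confirmed by the original question).
all_X = torch.randn(batch_size, num_cats, k, ffnn, dtype=torch.float64)     # b n k f
all_C = torch.randn(batch_size, num_cats, ffnn, ffnn, dtype=torch.float64)  # b n f g
all_Y = torch.randn(batch_size, num_cats, k, ffnn, dtype=torch.float64)     # b n l g

result = torch.einsum('bnkf, bnfg, bnlg -> bnkl', all_X, all_C, all_Y)

# f and g appear in the inputs but not in the output, so they are summed over.
# For each (b, n) pair this is the matrix product X @ C @ Y^T.
reference = all_X @ all_C @ all_Y.transpose(-1, -2)

shapes_match = result.shape == (batch_size, num_cats, k, k)
values_match = torch.allclose(result, reference)
print(shapes_match, values_match)
```

Here b and n are kept as batch-like indices, while f and g are the contracted pairs.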
If all_C is constructed from the weights of Linears, you may need
to transpose the last two dimensions (the ffnn dimensions) of all_C,
either explicitly or by swapping the symbolic indices in the einsum()
expression, in order to replicate the result you would get by applying
the Linear to all_Y.
(This again ignores any bias terms that may have appeared in your
Linears.)
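Whether a transpose is needed depends on the layout in which the weights were stored in all_C, so it is worth checking numerically. The sketch below assumes, purely for illustration, a single bias-free Linear shared across the batch and category dimensions, with made-up sizes. A Linear keeps its weight as [out_features, in_features] and computes input @ weight.T.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes and a single shared, bias-free Linear (assumptions).
batch_size, num_cats, k, ffnn = 2, 3, 4, 5
lin = torch.nn.Linear(ffnn, ffnn, bias=False).double()

all_X = torch.randn(batch_size, num_cats, k, ffnn, dtype=torch.float64)
all_Y = torch.randn(batch_size, num_cats, k, ffnn, dtype=torch.float64)

# Reference: apply the Linear to all_Y, then contract with all_X.
with torch.no_grad():
    target = all_X @ lin(all_Y).transpose(-1, -2)

# Case 1: all_C holds the weight exactly as the Linear stores it,
# [out_features, in_features].
all_C = lin.weight.detach().expand(batch_size, num_cats, ffnn, ffnn)
as_stored = torch.einsum('bnkf, bnfg, bnlg -> bnkl', all_X, all_C, all_Y)

# Case 2: all_C was built from the transposed weight, [in_features, out_features].
all_C_t = all_C.transpose(-1, -2)
unswapped = torch.einsum('bnkf, bnfg, bnlg -> bnkl', all_X, all_C_t, all_Y)
# Swapping the symbolic indices of the middle operand undoes the transpose.
swapped = torch.einsum('bnkf, bngf, bnlg -> bnkl', all_X, all_C_t, all_Y)

print(torch.allclose(as_stored, target))
print(torch.allclose(unswapped, target))
print(torch.allclose(swapped, target))
```

In this particular setup the weight as stored already reproduces the Linear, while the transposed layout needs the swapped indices. Your own construction of all_C may differ, so run the comparison on your actual tensors.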
Best.
K. Frank