Batch matrix multiplication of 3D tensors


I have a series of matrix multiplications in a for loop. I want to transform them into one “big” matrix multiplication so that they all run together and make better use of the GPU.

Here is the current implementation:

The model inputs x and y have shape [batch_size, k, config.hidden_size].
For each category id [0, 1, 2, 3] we compute:

  1. non-linear reps for x and y: cat_x_reps, cat_y_reps
  2. multiply cat_x_reps, cat_y_reps with a bilinear matrix:
     cat_x_reps x C_i x cat_y_reps

import torch
from torch import nn
from torch.nn import Dropout, LayerNorm, Linear, Module
from transformers.activations import ACT2FN

class FullyConnectedLayer(Module):
    def __init__(self, config, input_dim, output_dim, dropout_prob):
        super(FullyConnectedLayer, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dropout_prob = dropout_prob

        self.dense = Linear(self.input_dim, self.output_dim)
        self.layer_norm = LayerNorm(self.output_dim, eps=config.layer_norm_eps)
        self.activation_func = ACT2FN[config.hidden_act]
        self.dropout = Dropout(self.dropout_prob)

    def forward(self, inputs):
        temp = inputs
        temp = self.dense(temp)
        temp = self.activation_func(temp)
        temp = self.layer_norm(temp)
        temp = self.dropout(temp)
        return temp

class MyClass(nn.Module):
    def __init__(self, config, args):
        super(MyClass, self).__init__()
        self.ffnn_size = args.ffnn_size
        self.num_heads = args.num_heads
        # one MLP per category for x and for y
        self.x_cat_mlp = nn.ModuleList([FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, args.dropout_prob)
                                        for _ in range(self.num_heads)])
        self.y_cat_mlp = nn.ModuleList([FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, args.dropout_prob)
                                        for _ in range(self.num_heads)])
        # one bilinear matrix C_i (plus bias) per category
        self.classifiers = nn.ModuleList([nn.Linear(self.ffnn_size, self.ffnn_size)
                                          for _ in range(self.num_heads)])

    def forward(self, x, y):
        # x, y -> [batch_size, k, config.hidden_size]
        cat_logit_list = []
        for cat_id in range(self.num_heads):
            cat_x_reps = self.x_cat_mlp[cat_id](x)
            cat_y_reps = self.y_cat_mlp[cat_id](y)

            temp = self.classifiers[cat_id](cat_x_reps)                     # [batch_size, k, ffnn_size]
            cat_logits = torch.matmul(temp, cat_y_reps.permute([0, 2, 1]))  # [batch_size, k, k]
            cat_logit_list.append(cat_logits)

        return torch.stack(cat_logit_list, dim=1)                           # [batch_size, num_heads, k, k]

I was able to do it for the attribute self.x_cat_mlp because all of its operations are applied to the same input (x, and likewise y for self.y_cat_mlp).

I defined:

self.all_cats_size = self.ffnn_size * self.num_cats
self.x_cat_mlp = FullyConnectedLayer(config, config.hidden_size, self.all_cats_size, args.dropout_prob)
x_reps = self.x_cat_mlp(x)                                                          # [batch, k, ffnn * num_cats]
# split the last dim into (num_cats, ffnn) first, then move num_cats in front of k
cat_x_reps = x_reps.view(batch_size, k, self.num_cats, self.ffnn_size).permute(0, 2, 1, 3)  # [batch, num_cats, k, ffnn]
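
One thing worth double-checking with a fused layer like this is how the last dimension gets split back into categories. The following is a small sketch with toy sizes and a stand-in tensor instead of the real layer output (all names and sizes are assumptions for illustration). It suggests that view() followed by permute() keeps each token's features together, whereas viewing directly as [batch, num_cats, k, ffnn] reinterprets the memory layout:

```python
import torch

batch_size, k, num_cats, ffnn = 2, 3, 4, 5   # toy sizes, assumptions for illustration

# Stand-in for the fused layer output: [batch, k, num_cats * ffnn].
x_reps = torch.arange(batch_size * k * num_cats * ffnn, dtype=torch.float32)
x_reps = x_reps.view(batch_size, k, num_cats * ffnn)

# Split the last dim first, then move the category axis forward.
split = x_reps.view(batch_size, k, num_cats, ffnn).permute(0, 2, 1, 3)  # [batch, num_cats, k, ffnn]

# Category 1 should consist of features ffnn..2*ffnn of every token.
expected_cat1 = x_reps[:, :, ffnn:2 * ffnn]
print(torch.equal(split[:, 1], expected_cat1))    # True

# Viewing directly as [batch, num_cats, k, ffnn] mixes features of different tokens.
direct = x_reps.view(batch_size, num_cats, k, ffnn)
print(torch.equal(direct[:, 1], expected_cat1))   # False
```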

I am struggling to do it for self.classifiers: the input is different each time, i.e. cat_x_reps is different for each category.

Hi Shon!

My interpretation of the above is that:

            temp = self.classifiers[cat_id](cat_x_reps)

is the application of a Linear to cat_x_reps, and corresponds to
C_i x cat_y_reps in your summary.

Note that applying a Linear is not just a matrix multiplication, but
also entails adding the bias term:

linear (t) == t @ linear.weight.T + linear.bias

So describing C_i x cat_y_reps as just matrix multiplication is an oversimplification.
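
The identity above can be checked directly. This is a minimal sketch; the layer and tensor sizes are arbitrary assumptions:

```python
import torch
from torch import nn

torch.manual_seed(0)
linear = nn.Linear(5, 3)        # weight: [3, 5], bias: [3]
t = torch.randn(2, 4, 5)        # [batch, k, in_features]

by_module = linear(t)                                # what the module computes
by_hand = t @ linear.weight.T + linear.bias          # matrix product plus bias

print(torch.allclose(by_module, by_hand, atol=1e-6))
```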

Leaving the issue of the bias term aside for the moment, you can
stack() the weight tensors of your classifiers (along dim = 1, if
you choose), stack() the cat_x_reps and stack() the cat_y_reps,
and then use torch.einsum() to compute all of the three-term products,
cat_x_reps x C_i x cat_y_reps, for all values of i (which I take to be
cat_id) all at once in one “batch.”

Now for the bias term:

cat_x_reps x (C_i x cat_y_reps + bias_i) is equal to
cat_x_reps x C_i x cat_y_reps + cat_x_reps x bias_i so you can
compute the three-term product as above, and then stack() the bias_i,
compute cat_x_reps x bias_i for all i all at once (using, for example,
torch.einsum()), and then add it to the three-term product to get the
final result.


K. Frank

Hi @KFrank,
Thank you for your response!

Just to verify:
After stacking, the shapes are:

all_X  [batch_size, num_cats, k,     ffnn]
all_C  [batch_size, num_cats, ffnn,  ffnn]
all_Y  [batch_size, num_cats, k,     ffnn]

Can you elaborate how to use torch.einsum() for these shapes? we need to have:

for i in range(num_cat):
    x = all_X[:, i, :, :]                               # [batch_size, k,     ffnn]
    c = all_C[:, i, :, :]                               # [batch_size, ffnn,  ffnn]
    y = all_Y[:, i, :, :]                               # [batch_size, k,     ffnn]

    temp = torch.matmul(x, c)                           # [batch_size, k, ffnn_size]
    logits = torch.matmul(temp, y.permute([0, 2, 1]))   # [batch_size, k, k]

If we simply compute
all_X x all_C x all_Y
it would include all the different combinations, for example X_i x C_i x Y_j, which I don’t need.

Hi Shon!

It’s not entirely clear to me what the values num_cats and k are.

Let me assume that you want the shape of your final “all the multiplication
together” result to be [batch_size, num_cats, k, k].

You can compute your three-term product with einsum():

result = torch.einsum ('bnkf, bnfg, bnlg -> bnkl', all_X, all_C, all_Y)

Make sure you understand how einsum() works so that you are
contracting (“multiplying” together) the desired pairs of indices.
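
As a sanity check of the index bookkeeping, here is a minimal sketch with random tensors and illustrative sizes (the sizes and the names `looped` and `chained` are assumptions). Because b and n appear in every operand and in the output, they are never summed over, so category i of all_X only meets category i of all_C and all_Y:

```python
import torch

torch.manual_seed(0)
batch_size, num_cats, k, ffnn = 2, 4, 3, 5   # illustrative sizes

all_X = torch.randn(batch_size, num_cats, k, ffnn)
all_C = torch.randn(batch_size, num_cats, ffnn, ffnn)
all_Y = torch.randn(batch_size, num_cats, k, ffnn)

# f and g are contracted; b, n, k and l survive into the output.
result = torch.einsum('bnkf,bnfg,bnlg->bnkl', all_X, all_C, all_Y)

# The same computation as an explicit loop over categories.
looped = torch.stack(
    [all_X[:, i] @ all_C[:, i] @ all_Y[:, i].transpose(-1, -2) for i in range(num_cats)],
    dim=1,
)

# torch.matmul also batches over the leading (batch, category) dimensions.
chained = all_X @ all_C @ all_Y.transpose(-1, -2)

print(result.shape)
print(torch.allclose(result, looped, rtol=1e-4, atol=1e-4))
print(torch.allclose(result, chained, rtol=1e-4, atol=1e-4))
```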

If all_C is constructed from the weights of Linears, you may need
to transpose the last two dimensions (the ffnn dimensions) of all_C,
either explicitly or by swapping the symbolic indices in the einsum()
expression, in order to replicate the result you would get by applying
the Linear to all_Y.

(This again ignores any bias terms that may have appeared in your Linears.)
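
To illustrate the transposition point, here is a hedged sketch. It assumes bias-free Linears that are applied to all_X, as in the loop from the original question, and weights stacked without a batch dimension. The sizes and names (`linears`, `all_W`, `explicit`, `swapped`) are hypothetical:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_cats, k, ffnn = 2, 4, 3, 5   # illustrative sizes

all_X = torch.randn(batch_size, num_cats, k, ffnn)
all_Y = torch.randn(batch_size, num_cats, k, ffnn)
linears = [nn.Linear(ffnn, ffnn, bias=False) for _ in range(num_cats)]
all_W = torch.stack([lin.weight for lin in linears])   # [num_cats, out, in]

# Option 1: transpose explicitly so that all_C[n] maps input to output features.
all_C = all_W.transpose(-1, -2)
explicit = torch.einsum('bnkf,nfg,bnlg->bnkl', all_X, all_C, all_Y)

# Option 2: keep the raw weights and swap the symbolic indices instead.
swapped = torch.einsum('bnkf,ngf,bnlg->bnkl', all_X, all_W, all_Y)

# Reference: apply each Linear to its slice of all_X, then multiply by Y^T.
reference = torch.stack(
    [linears[i](all_X[:, i]) @ all_Y[:, i].transpose(-1, -2) for i in range(num_cats)],
    dim=1,
)

print(torch.allclose(explicit, reference, atol=1e-4))
print(torch.allclose(swapped, reference, atol=1e-4))
```

If the Linear were instead applied to all_Y, the roles of the two weight indices would swap, so it is worth checking against a loop reference like the one above.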


K. Frank

Thanks! This is working. Wow, einsum is such a powerful method!

k is the sequence length.
num_cats is the number of “learning” matrices we have.

You're right, I want [batch_size, num_cats, k, k].

I took your note about swapping the weights' dimensions into account.
In addition, all_C holds the learnable matrices and its shape is [num_cats, ffnn, ffnn].

I am struggling a bit to add the bias term as you suggested:
after calculating the three-term product I get a tensor logits of shape [batch_size, num_cats, k, k].

Now I need to do:

cat_x_reps x (C_i x cat_y_reps + bias_i) is equal to
cat_x_reps x C_i x cat_y_reps + cat_x_reps x bias_i so you can
compute the three-term product as above, and then stack() the bias_i,
compute cat_x_reps x bias_i for all i all at once (using, for example,
torch.einsum()), and then add it to the three-term product to get the
final result.

So I stacked all the bias terms → all_bias = [num_cats, ffnn]
and added them:
logits + torch.einsum('bnkf, nf -> bnk', all_X, all_bias).unsqueeze(-1)

But this doesn’t reproduce the same performance (it is not far from the for-loop method, but still behind).

BTW, I ran the for-loop method with bias=False when initializing the Linear modules, then calculated the performance (loss).
Then I did the same with only the 3-term product and got exactly the same loss.
So I am only struggling with the bias. Any idea why?

@KFrank Thanks in advance, you have helped me a lot!!

Hi Shon!

I’m not sure what you’re asking.

Is your issue that einsum() runs more slowly than the loop (but gives the
same answer)?

Or is einsum() giving you different values than the loop does?

If the latter, check that any differences aren’t due to expected floating-point
round-off error. You might try comparing the results with torch.allclose(), or
performing the computations with double-precision to see if that reduces the differences.
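
A rough sketch of such a check, using random stand-in tensors (the sizes, the tolerances and the helper `compare` are assumptions for illustration, not part of the original code):

```python
import torch

torch.manual_seed(0)
batch_size, num_cats, k, ffnn = 2, 4, 3, 16   # illustrative sizes

# Create the data once in double precision, then reuse it in single precision.
all_X = torch.randn(batch_size, num_cats, k, ffnn, dtype=torch.float64)
all_C = torch.randn(num_cats, ffnn, ffnn, dtype=torch.float64)
all_Y = torch.randn(batch_size, num_cats, k, ffnn, dtype=torch.float64)

def compare(X, C, Y):
    """Return (largest absolute gap, allclose verdict) between loop and einsum."""
    looped = torch.stack(
        [X[:, i] @ C[i] @ Y[:, i].transpose(-1, -2) for i in range(num_cats)], dim=1)
    batched = torch.einsum('bnkf,nfg,bnlg->bnkl', X, C, Y)
    gap = (looped - batched).abs().max().item()
    return gap, torch.allclose(looped, batched, rtol=1e-4, atol=1e-4)

diff32, close32 = compare(all_X.float(), all_C.float(), all_Y.float())
diff64, close64 = compare(all_X, all_C, all_Y)

print(f"float32: max abs diff = {diff32:.2e}, allclose = {close32}")
print(f"float64: max abs diff = {diff64:.2e}, allclose = {close64}")
```

If a gap persists at roughly the same size in double precision, round-off is unlikely to be the explanation and the two computations probably differ mathematically.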


K. Frank

I found the issue:
in the for loop, the linear transformation is applied to X and the result is then multiplied by Y.

So instead of
cat_x_reps x C_i x cat_y_reps + cat_x_reps x bias
I need to do:
cat_x_reps x C_i x cat_y_reps + bias x cat_y_reps

Now I can reproduce the same loss score.
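
For anyone landing here later, this is a minimal sketch of the batched computation with the corrected bias term, checked against the loop. The sizes are illustrative, and random tensors stand in for the per-category representations, which is an assumption rather than the original model:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_cats, k, ffnn = 2, 4, 3, 5   # illustrative sizes

all_X = torch.randn(batch_size, num_cats, k, ffnn)
all_Y = torch.randn(batch_size, num_cats, k, ffnn)
classifiers = nn.ModuleList([nn.Linear(ffnn, ffnn) for _ in range(num_cats)])

# Loop version: the Linear (weight and bias) is applied to X, then multiplied by Y^T.
looped = torch.stack(
    [classifiers[i](all_X[:, i]) @ all_Y[:, i].transpose(-1, -2) for i in range(num_cats)],
    dim=1,
)                                                              # [batch, num_cats, k, k]

all_W = torch.stack([clf.weight for clf in classifiers])      # [num_cats, ffnn, ffnn]
all_bias = torch.stack([clf.bias for clf in classifiers])     # [num_cats, ffnn]

# Three-term product; 'ngf' accounts for Linear storing its weight as [out, in].
product = torch.einsum('bnkf,ngf,bnlg->bnkl', all_X, all_W, all_Y)

# Bias term: bias_i . y_l depends only on the column index l,
# so it is broadcast over the row index k.
bias_term = torch.einsum('ng,bnlg->bnl', all_bias, all_Y).unsqueeze(-2)  # [batch, num_cats, 1, k]

batched = product + bias_term
print(torch.allclose(batched, looped, atol=1e-4))
```

Note the unsqueeze(-2): the bias contribution varies along the last dimension (the Y positions), not along the X positions.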

Again, thank you very much!