Computing a bilinear form with a block-diagonal matrix efficiently

I want to compute $x^TAx$, where $x$ is a vector and $A$ is a matrix (i.e., a bilinear form). Although nn.Bilinear would work, in my case $A$ is block-diagonal. So I wanted to create an nn.Module that computes this with much less memory and computation by storing only the diagonal blocks.

For example, below is the computation using the functional form of nn.Bilinear, torch.nn.functional.bilinear(x, x, A, b) (here x’s batch size is 3 and the output dimension is 2):

x = tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
A = tensor([[[ 0.,  1.,  0.,  0.],
         [ 2.,  3.,  0.,  0.],
         [ 0.,  0., -0., -1.],
         [ 0.,  0., -2., -3.]],

        [[ 0.,  2.,  0.,  0.],
         [ 4.,  6.,  0.,  0.],
         [ 0.,  0., -0., -2.],
         [ 0.,  0., -4., -6.]]])
b = tensor([0., 1.])

Here x.shape is torch.Size([3, 4]), the weight A has shape torch.Size([2, 4, 4]), and the bias b has shape torch.Size([2]),
and the result of the bilinear computation is:

tensor([[ -42.,  -83.],
        [-138., -275.],
        [-234., -467.]])

Initially, I tried the following approach. A quadratic form with a block-diagonal matrix equals the sum of smaller quadratic forms over the blocks: if $x = [x_1, x_2]$ and $A = \begin{pmatrix} A_1 & 0 \\ 0 & A_2 \end{pmatrix}$, then $x^TAx = x_1^TA_1x_1 + x_2^TA_2x_2$. So I split x into chunks matching the block sizes, reshaped it so that the batch dimension was multiplied by the number of blocks, and then computed the bilinear form with this enlarged batch, but it didn’t work as expected.
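The block decomposition itself is easy to check numerically. A minimal sketch, assuming two 2×2 blocks with the same values as the first output’s blocks above; torch.block_diag() assembles the full matrix from its blocks:

```python
import torch

# Two 2x2 diagonal blocks (same values as the first output's blocks above)
A1 = torch.tensor([[0., 1.], [2., 3.]])
A2 = torch.tensor([[0., -1.], [-2., -3.]])
A = torch.block_diag(A1, A2)          # full 4x4 block-diagonal matrix

x = torch.tensor([0., 1., 2., 3.])
x1, x2 = x[:2], x[2:]                 # chunks matching the block sizes

full = x @ A @ x                      # x^T A x with the full matrix
split = x1 @ A1 @ x1 + x2 @ A2 @ x2   # sum of the per-block quadratic forms
print(full, split)                    # both are tensor(-42.)
```

Note that nn.Bilinear applies one and the same weight to every row of its batch. Folding the chunks into the batch dimension therefore cannot pair chunk k with block k (and the per-chunk results still have to be summed afterwards), which is presumably why the reshape approach did not work as expected.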

Any suggestions or ideas would be greatly appreciated!
(I thought a for loop of torch.matmul over the diagonal blocks could work, but I am afraid it would not run in parallel and would therefore be slower.)
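For reference, such a loop might look like the following minimal sketch. The helper name `block_bilinear_loop` and the `(out, n_blocks, bs, bs)` layout for the stored blocks are assumptions for illustration. It is correct, but it issues separate small operations for every block, which is the overhead the question is worried about:

```python
import torch

def block_bilinear_loop(x, blocks, bias):
    """Hypothetical helper: x is (batch, n_blocks * bs), blocks is
    (out, n_blocks, bs, bs), bias is (out,). Returns (batch, out)."""
    n_blocks, bs = blocks.shape[1], blocks.shape[2]
    chunks = x.reshape(x.shape[0], n_blocks, bs)   # split features into chunks
    result = bias.expand(x.shape[0], -1)           # start from the bias
    for k in range(n_blocks):                      # python loop over blocks
        xk = chunks[:, k, :]                       # (batch, bs)
        # (batch, bs) @ (out, bs, bs) -> (out, batch, bs), then dot with xk
        result = result + ((xk @ blocks[:, k]) * xk).sum(-1).T
    return result

x = torch.arange(12.).reshape(3, 4)
blocks = torch.tensor([[[[0., 1.], [2., 3.]], [[-0., -1.], [-2., -3.]]],
                       [[[0., 2.], [4., 6.]], [[-0., -2.], [-4., -6.]]]])
b = torch.tensor([0., 1.])
print(block_bilinear_loop(x, blocks, b))   # same values as the result above
```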

Hi Kore (aka Danny)!

I don’t believe that there is any way (without a loop) to perform your
computation with Bilinear that doesn’t have inefficiencies such as
multiplying things with the zero blocks.

Probably the simplest way to perform the necessary tensor multiplications
is to use einsum().

Let’s assume that you store just the non-zero blocks of A (or pre-process
A to extract its non-zero blocks). You can then perform the bilinear
“contraction” with a .reshape() of x and einsum():

import torch

x = torch.tensor([[ 0., 1.,  2.,  3.],
                  [ 4., 5.,  6.,  7.],
                  [ 8., 9., 10., 11.]])
A = torch.tensor([[[ 0.,  1.,  0.,  0.],
                   [ 2.,  3.,  0.,  0.],
                   [ 0.,  0., -0., -1.],
                   [ 0.,  0., -2., -3.]],
                  [[ 0.,  2.,  0.,  0.],
                   [ 4.,  6.,  0.,  0.],
                   [ 0.,  0., -0., -2.],
                   [ 0.,  0., -4., -6.]]])
b = torch.tensor([0., 1.])

resultA = torch.nn.functional.bilinear (x, x, A, b)

print ('resultA = ...')
print (resultA)

A_block = torch.tensor([[[[ 0.,  1.],      # store just the non-zero blocks of A
                          [ 2.,  3.]],
                         [[-0., -1.],
                          [-2., -3.]]],
                        [[[ 0.,  2.],
                          [ 4.,  6.]],
                         [[-0., -2.],
                          [-4., -6.]]]])

resultB = torch.einsum ('mij, nijk, mik -> mn', x.reshape (-1, 2, 2), A_block, x.reshape (-1, 2, 2)) + b

print ('resultB = ...')
print (resultB)

print ('torch.equal (resultA, resultB) =', torch.equal (resultA, resultB))

Here is the above script’s output:

resultA = ...
tensor([[ -42.,  -83.],
        [-138., -275.],
        [-234., -467.]])
resultB = ...
tensor([[ -42.,  -83.],
        [-138., -275.],
        [-234., -467.]])
torch.equal (resultA, resultB) = True
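Since the original goal was an nn.Module that stores only the blocks, here is a minimal sketch that wraps the einsum() above. The class name `BlockDiagonalBilinear`, its constructor arguments, the random initialization, and the `from_dense()` helper (which pulls the diagonal blocks out of a full weight with torch.diagonal(), i.e., the “pre-process A” option) are all hypothetical, not an existing pytorch API:

```python
import torch
import torch.nn as nn

class BlockDiagonalBilinear(nn.Module):
    """Hypothetical module computing x^T A_o x + b_o for each output o,
    where each A_o is block-diagonal and only its diagonal blocks are stored."""

    def __init__(self, n_blocks, block_size, out_features):
        super().__init__()
        self.n_blocks, self.block_size = n_blocks, block_size
        # weight[o, k] is the k-th diagonal block of A_o
        self.weight = nn.Parameter(
            0.1 * torch.randn(out_features, n_blocks, block_size, block_size))
        self.bias = nn.Parameter(torch.zeros(out_features))

    @classmethod
    def from_dense(cls, A, bias, block_size):
        """Build the module from a full (out, N, N) block-diagonal weight."""
        out, N, _ = A.shape
        n_blocks = N // block_size
        A5 = A.reshape(out, n_blocks, block_size, n_blocks, block_size)
        # keep block-row k / block-column k; torch.diagonal moves that
        # index to the last dim, so permute it back to position 1
        blocks = torch.diagonal(A5, dim1=1, dim2=3).permute(0, 3, 1, 2)
        module = cls(n_blocks, block_size, out)
        with torch.no_grad():
            module.weight.copy_(blocks)
            module.bias.copy_(bias)
        return module

    def forward(self, x):
        xb = x.reshape(x.shape[0], self.n_blocks, self.block_size)
        return torch.einsum('mij, nijk, mik -> mn', xb, self.weight, xb) + self.bias

# Usage with the same A, b, and x as in the script above
B1 = torch.tensor([[0., 1.], [2., 3.]])
A = torch.stack([torch.block_diag(B1, -B1), 2 * torch.block_diag(B1, -B1)])
b = torch.tensor([0., 1.])
x = torch.arange(12.).reshape(3, 4)

layer = BlockDiagonalBilinear.from_dense(A, b, block_size=2)
print(layer(x))
print(torch.allclose(layer(x), torch.nn.functional.bilinear(x, x, A, b)))
```

The module holds out × n_blocks × bs × bs weights instead of out × N × N, and the einsum() never touches the zero blocks.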


K. Frank

Thank you so much for your answer!

I actually solved it myself using a similar method: I added an extra dimension (of size equal to the number of blocks) and then used torch.matmul twice. However, your solution seems more general (I think it might be helpful when I need to implement a convolutional version of this operation).
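The exact code isn’t shown, but a matmul-based version along those lines might look like this minimal sketch (an assumption about the approach, with a hypothetical helper name). The block index becomes a batch dimension that matmul() broadcasts over, and the per-block results are summed at the end:

```python
import torch

def block_bilinear_matmul(x, blocks, bias):
    """Hypothetical helper: x is (batch, n_blocks * bs), blocks is
    (out, n_blocks, bs, bs), bias is (out,). Returns (batch, out)."""
    batch = x.shape[0]
    n_blocks, bs = blocks.shape[1], blocks.shape[2]
    row = x.reshape(batch, 1, n_blocks, 1, bs)   # x_k^T as row vectors
    col = x.reshape(batch, 1, n_blocks, bs, 1)   # x_k as column vectors
    # (batch, 1, nb, 1, bs) @ (out, nb, bs, bs) -> (batch, out, nb, 1, bs)
    # then @ (batch, 1, nb, bs, 1)             -> (batch, out, nb, 1, 1)
    per_block = row @ blocks @ col
    # sum the per-block quadratic forms and drop the two singleton dims
    return per_block.sum(dim=(2, 3, 4)) + bias

x = torch.arange(12.).reshape(3, 4)
blocks = torch.tensor([[[[0., 1.], [2., 3.]], [[-0., -1.], [-2., -3.]]],
                       [[[0., 2.], [4., 6.]], [[-0., -2.], [-4., -6.]]]])
b = torch.tensor([0., 1.])
print(block_bilinear_matmul(x, blocks, b))
```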

If you don’t mind, I have just a few questions:

  1. Is defining custom operations (i.e., instead of using things like F.linear) detrimental to speed? Does it affect speed to the point where I need to try out different implementations of the same operation, measure their run times, and choose the best one?
  2. Are there any other essential operations I should know about (things like einsum)? It’s my first time making a custom nn.Module, and it seems that knowing which operations to use when is really important for writing code that avoids loops.
  3. If I wanted to extend this operation to a matrix made of blocks of different sizes, would it be efficient, or even possible?

Again, thank you so much!

Best, Danny Han

Hi Danny!

In general, you can replace einsum() with appropriately chosen
permute()s, matmul()s, and sum()s. (This is, roughly speaking,
what einsum() does under the hood.)

But einsum() is really useful and I find it easier to read than various
matmul() implementations for complicated contractions (such as yours).

(Note, it goes the other way, too. You can replace a matmul() that is
performing a single matrix multiplication with einsum(). But in that case
I would find matmul() easier to read.)
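A small illustration of the equivalence in both directions (the tensors and the index letters are arbitrary examples):

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 4)
b = torch.randn(4, 5)

# a single matrix product: matmul() reads more naturally here
p1 = a @ b
p2 = torch.einsum('ij, jk -> ik', a, b)

# a contraction written with einsum() and with permute()/matmul()/sum()
x = torch.randn(6, 2, 3)       # (batch, group, feature)
w = torch.randn(5, 2, 3)       # (out, group, feature)
q1 = torch.einsum('bgf, ogf -> bo', x, w)
q2 = (x.permute(1, 0, 2) @ w.permute(1, 2, 0)).sum(0)   # sum over groups

print(torch.allclose(p1, p2, atol=1e-5), torch.allclose(q1, q2, atol=1e-5))
```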

On whether custom operations hurt speed: in general, no, as long as you
implement your custom operation with pytorch tensor operations and avoid
explicit python loops. Of course, you can do things in a suboptimal way,
but it’s up to you not to do that.

For most realistic applications, the big work is in performing tensor
multiplications of (large) tensors or applying functions element-wise
to (large) tensors. Pytorch does these things with pytorch tensor
operations that are optimized for gpu (and cpu) floating-point pipelines.

The python glue that you write to chain such operations together is,
in a sense, inefficient, but for most practical problems, the time spent
executing the “inefficient” python is small compared to the real work.

For cases where you use python to chain together a lot of small tensor
operations, you might find that using things like torch.compile speeds
things up dramatically.

On whether to time different implementations: sometimes, yes. If speed
is very important to your use case, it can make sense to time alternative
implementations. Note that the fastest approach can depend (in unexpected
ways) on the size of the tensors, gpu vs. cpu, single vs. double precision,
and one hardware platform vs. another.
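If you do want to compare implementations, torch.utils.benchmark takes care of warm-up and (on gpu) synchronization. A minimal sketch, assuming the two formulations below and arbitrary sizes; the timings you get will depend on your hardware, dtype, and tensor sizes, which is exactly why it is worth measuring:

```python
import torch
import torch.utils.benchmark as benchmark

batch, n_blocks, bs, out = 256, 16, 8, 32   # arbitrary sizes for illustration
x = torch.randn(batch, n_blocks * bs)
blocks = torch.randn(out, n_blocks, bs, bs)

def via_einsum(x, blocks):
    xb = x.reshape(x.shape[0], n_blocks, bs)
    return torch.einsum('mij, nijk, mik -> mn', xb, blocks, xb)

def via_matmul(x, blocks):
    row = x.reshape(x.shape[0], 1, n_blocks, 1, bs)
    col = x.reshape(x.shape[0], 1, n_blocks, bs, 1)
    return (row @ blocks @ col).sum(dim=(2, 3, 4))

for name, fn in [('einsum', via_einsum), ('matmul', via_matmul)]:
    timer = benchmark.Timer(stmt='fn(x, blocks)',
                            globals={'fn': fn, 'x': x, 'blocks': blocks})
    m = timer.timeit(20)                    # runs the statement 20 times
    print(f'{name}: {m.mean * 1e3:.3f} ms per call')
```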

As an example, there are cases discussed in this forum where einsum()
significantly underperformed or outperformed an equivalent matmul()
implementation. I’m not sure that this is still the case: where the
discrepancy was significant, it was fair to call it a performance bug
in einsum() or matmul(), so those bugs have likely been fixed.

But just do a good job writing your code the first time around, and don’t
waste your time performance-tuning various pieces of code unless you’ve
identified them as bottlenecks in your overall computation. If you make
something that accounts for only 1% of your runtime ten times faster,
who cares?

As for other operations worth knowing: einsum() is probably one of the
best examples of a pytorch function that can be broadly useful, but is
underused.

expand() is useful when you need to repeat slices of a tensor (say for
operations that don’t support broadcasting). It gives you an expanded
view of the tensor, rather than materializing the expanded tensor, which
can offer a large performance benefit.
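A small illustration of the difference between expand() and repeat() (the tensor is an arbitrary example): expand() returns a view with stride 0 along the new dimension that shares memory with the original, while repeat() copies the data.

```python
import torch

v = torch.arange(3.)                    # tensor([0., 1., 2.])
e = v.unsqueeze(0).expand(4, 3)         # view: no new memory allocated
r = v.unsqueeze(0).repeat(4, 1)         # materialized copy

print(e.stride(), r.stride())           # (0, 1) (3, 1)
print(e.data_ptr() == v.data_ptr())     # True: e shares v's storage
print(torch.equal(e, r))                # True: same values

# because e is a view, writing to v shows up in every "row" of e
v[0] = 10.
print(e[:, 0])                          # tensor([10., 10., 10., 10.])
```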

To the Forum: Help out here! What are some other underloved pytorch
operations that people should put in their general-purpose toolkit?

Blocks of different sizes would be harder. Pytorch does not support
“ragged” tensors, that is, tensors whose slices have differing shapes.
As a simple example, a “matrix” whose first row has five elements and
whose second row has three elements would be ragged.

One approach is to pad the ragged slices up to some common uniform
size. (This is conceptually similar to your including zero-blocks in your
full block-diagonal matrix.)
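A minimal sketch of the padding approach, assuming blocks of sizes 2, 3, and 1, a single output, and random values (the helpers `pad_blocks` and `pad_chunks` are hypothetical). Each block is zero-padded to the largest block size and each chunk of x is zero-padded to match, so the padded entries contribute nothing and the einsum() from earlier in the thread applies unchanged:

```python
import torch
import torch.nn.functional as F

def pad_blocks(blocks):
    """Hypothetical helper: zero-pad square blocks of different sizes
    to a common (max_size, max_size) and stack them."""
    size = max(blk.shape[0] for blk in blocks)
    return torch.stack([F.pad(blk, (0, size - blk.shape[0], 0, size - blk.shape[0]))
                        for blk in blocks])

def pad_chunks(x, sizes):
    """Hypothetical helper: split the features of x into chunks of the given
    sizes and zero-pad each chunk: (batch, sum(sizes)) -> (batch, n, max)."""
    size = max(sizes)
    chunks = torch.split(x, sizes, dim=1)
    return torch.stack([F.pad(c, (0, size - c.shape[1])) for c in chunks], dim=1)

torch.manual_seed(0)
sizes = [2, 3, 1]
blocks = [torch.randn(s, s) for s in sizes]     # one output, three blocks
x = torch.randn(4, sum(sizes))

W = pad_blocks(blocks).unsqueeze(0)             # (out=1, n_blocks, 3, 3)
xp = pad_chunks(x, sizes)                       # (batch, n_blocks, 3)
padded = torch.einsum('mij, nijk, mik -> mn', xp, W, xp)

# reference: quadratic form with the full block-diagonal matrix
dense = torch.einsum('mi, ij, mj -> m', x, torch.block_diag(*blocks), x)
print(torch.allclose(padded[:, 0], dense, atol=1e-4))
```

The cost is wasted work on the padding, which grows when the block sizes vary a lot.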

You might be able to package your matrix with blocks of different sizes
as a sparse tensor, but sparse tensors in pytorch are only partially
supported, so you might not be able to perform all of the steps you want.
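A minimal sketch of the sparse route, assuming COO storage and random blocks of sizes 2, 3, and 1. For simplicity it builds the dense matrix with torch.block_diag() and converts it with to_sparse(); in practice you would presumably construct the sparse indices directly. One torch.sparse.mm() call computes A_o x for all outputs at once, and the quadratic form is finished with an elementwise product and a sum. Whether this beats the padded dense version will depend heavily on sizes and hardware:

```python
import torch

torch.manual_seed(0)
sizes = [2, 3, 1]
out = 2
# A[o] is block-diagonal with differently sized blocks
A = torch.stack([torch.block_diag(*[torch.randn(s, s) for s in sizes])
                 for _ in range(out)])               # (out, N, N), dense
N = A.shape[1]
A_sparse = A.reshape(out * N, N).to_sparse()         # COO: only nonzeros stored

x = torch.randn(4, N)                                # (batch, N)
Ax = torch.sparse.mm(A_sparse, x.T)                  # (out * N, batch), dense
Ax = Ax.reshape(out, N, -1)                          # (out, N, batch)
result = (Ax * x.T).sum(dim=1).T                     # (batch, out)

reference = torch.nn.functional.bilinear(x, x, A)    # dense check, no bias
print(torch.allclose(result, reference, atol=1e-4))
```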

So, if worst comes to worst, you may have to write a python loop that
processes the differently-sized blocks one by one (and maybe use
torch.compile to undo some of the python-loop damage).
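A minimal sketch of that loop fallback, assuming the blocks are kept in a Python list in which entry k has shape (out, s_k, s_k) (the layout and the helper name `ragged_block_bilinear` are hypothetical). The loop runs over blocks rather than batch elements, so each iteration is still a vectorized tensor operation:

```python
import torch

def ragged_block_bilinear(x, blocks, bias):
    """Hypothetical helper: blocks[k] has shape (out, s_k, s_k), x has shape
    (batch, sum of s_k), bias has shape (out,). Returns (batch, out)."""
    sizes = [blk.shape[-1] for blk in blocks]
    result = bias.expand(x.shape[0], -1)
    for xk, blk in zip(torch.split(x, sizes, dim=1), blocks):
        # x_k^T A_k x_k for every output and every batch row at once
        result = result + torch.einsum('mi, nij, mj -> mn', xk, blk, xk)
    return result

torch.manual_seed(0)
sizes, out = [2, 3, 1], 2
blocks = [torch.randn(out, s, s) for s in sizes]
bias = torch.randn(out)
x = torch.randn(4, sum(sizes))

# check against the full block-diagonal weight
A = torch.stack([torch.block_diag(*[blk[o] for blk in blocks]) for o in range(out)])
print(torch.allclose(ragged_block_bilinear(x, blocks, bias),
                     torch.nn.functional.bilinear(x, x, A, bias), atol=1e-4))
```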


K. Frank