Fastest way to compute multiple outer products

Given a 2D matrix of size 2000 x 1000, I need to compute the outer product of each row with itself and then average all of these outer products. What is the fastest/most efficient way of doing so?
This is by far the biggest bottleneck in my program. I've come up with two solutions; the second, against all intuition, is much faster than the first. Any help to improve the speed of this computation will be greatly appreciated. The computation will be performed on the GPU, if that's of any relevance.

Slowest:

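import torch

def avg_matrix_outer_products_v1(a):
    x_dim, y_dim = a.shape
    # batched outer products: (x_dim, y_dim, 1) @ (x_dim, 1, y_dim) -> (x_dim, y_dim, y_dim)
    outer_products = torch.matmul(a.view(x_dim, y_dim, 1), a.view(x_dim, 1, y_dim))
    # average over the row dimension (no deprecated .T on a 3-D tensor needed)
    return torch.mean(outer_products, 0)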

This method can run into memory issues, but this is easily fixed by splitting “a” into submatrices of only 250-500 rows and calling the function multiple times.
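A minimal sketch of that chunking idea, assuming a hypothetical helper name `avg_matrix_outer_products_chunked` and a default `chunk_size` of 250. Materializing every outer product at once needs 2000 x 1000 x 1000 float32 values, about 8 GB, while a 250-row chunk needs about 1 GB. The sketch sums each chunk's outer products into one running total and divides by the total row count once at the end, so the result stays correct even when the last chunk is shorter than the others (averaging per-chunk means would weight a short chunk too heavily).

```python
import torch

def avg_matrix_outer_products_chunked(a, chunk_size=250):
    # Hypothetical helper: batched outer products over row chunks,
    # accumulated into a running sum so that at most
    # chunk_size * y_dim * y_dim elements exist at any one time.
    x_dim, y_dim = a.shape
    total = torch.zeros(y_dim, y_dim, dtype=a.dtype, device=a.device)
    for chunk in torch.split(a, chunk_size):  # splits along dim 0
        # (c, y_dim, 1) @ (c, 1, y_dim) -> (c, y_dim, y_dim)
        outer = chunk.unsqueeze(2) @ chunk.unsqueeze(1)
        total += outer.sum(dim=0)
    # divide once by the full row count, not per chunk
    return total / x_dim
```

With `chunk_size=3` on a 10-row input the last chunk holds a single row, which is exactly the case the single final division handles.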

Fastest:

import torch

def avg_matrix_outer_products_v2(a):
    x_dim = a.shape[0]
    # running sum of the per-row outer products, one row at a time
    outer_products = torch.outer(a[0], a[0])
    for j in range(1, x_dim):
        outer_products += torch.outer(a[j], a[j])

    return outer_products / x_dim
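A likely reason v2 beats v1: v1 materializes every outer product before averaging, while v2 only ever holds one 1000 x 1000 running sum plus one temporary outer product. A minimal variant, assuming the hypothetical name `avg_matrix_outer_products_v2_inplace`, drops that temporary as well by accumulating with `Tensor.addr_`, which adds `outer(vec1, vec2)` into a tensor in place. It still loops over the rows in Python, so it issues one small kernel per row.

```python
import torch

def avg_matrix_outer_products_v2_inplace(a):
    # Hypothetical variant of v2: acc.addr_(u, v) performs
    # acc += torch.outer(u, v) in place, so no per-row
    # (y_dim, y_dim) temporary is allocated.
    x_dim, y_dim = a.shape
    acc = torch.zeros(y_dim, y_dim, dtype=a.dtype, device=a.device)
    for j in range(x_dim):
        acc.addr_(a[j], a[j])
    return acc / x_dim
```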

If you have a matrix of size 2000 x 1000 and you want the outer product of the i-th row with itself, you can use torch.einsum.

So, for example, for a matrix a of shape [i, j], torch.einsum("ij,ik->ijk", a, a) will return all of the row-wise outer products stacked into a tensor of shape [i, j, j].
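A small sketch of the einsum route, assuming hypothetical helper names. einsum does not accept a repeated index in the output subscripts, so the per-row outer product needs two operands ("ij,ik->ijk"). For the averaged result, letting einsum also sum over i ("ij,ik->jk") avoids building the 3-D tensor at all; that contraction is the same one as a.T @ a.

```python
import torch

def per_row_outer_products(a):
    # result[i] == torch.outer(a[i], a[i]); shape (x_dim, y_dim, y_dim).
    # For a 2000 x 1000 float32 input this is about 8 GB, so only use it
    # on small inputs or row chunks.
    return torch.einsum("ij,ik->ijk", a, a)

def avg_outer_product_einsum(a):
    # Summing over i inside the contraction never materializes the
    # per-row outer products; equivalent to a.T @ a / x_dim.
    return torch.einsum("ij,ik->jk", a, a) / a.shape[0]

a = torch.randn(5, 3)
print(per_row_outer_products(a).shape)  # torch.Size([5, 3, 3])
```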

Hi Aesteban!

The calculation you are asking for can be performed with a single
matrix multiplication:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> def avg_matrix_outer_products_v2(a):
...      x_dim = a.shape[0]
...      ourter_products = torch.outer(a[0], a[0])
...      for j in range(1, x_dim):
...          ourter_products += torch.outer(a[j], a[j])
...      return ourter_products / x_dim
...
>>> a = torch.randn (2000, 1000)
>>> outA = avg_matrix_outer_products_v2(a)
>>> outB = a.T @ a / a.shape[0]
>>> outB.allclose (outA, atol = 1.e-7)
True

(And, yes, I gave up on your _v1 version because it was annoyingly
slow and memory inefficient.)
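Why the single matrix multiplication gives the same result: write a_i for the i-th row of the n x d matrix A (as a column vector), here with n = 2000 and d = 1000. Entry (j, k) of A^T A is a sum over rows of a_ij * a_ik, which is exactly entry (j, k) of the summed outer products.

```latex
\left(A^{\top}A\right)_{jk} = \sum_{i=1}^{n} A_{ij}\,A_{ik}
  = \sum_{i=1}^{n} \left(a_i a_i^{\top}\right)_{jk}
\quad\Longrightarrow\quad
\frac{1}{n}\sum_{i=1}^{n} a_i a_i^{\top} = \frac{1}{n}\,A^{\top}A
```

Both the loop and the matmul perform about 2nd^2 = 2 x 2000 x 1000^2 = 4 x 10^9 floating-point operations; the matmul simply does them in one optimized GEMM call instead of a Python loop issuing thousands of small kernel launches, which is consistent with the large speedup reported in the reply.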

Best.

K. Frank


Thanks a lot! Works really well! My program just gained a 10x speedup lol.