Fastest way to compute multiple outer products

Hi Aesteban!

The calculation you are asking for can be performed with a single
matrix multiplication:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> def avg_matrix_outer_products_v2(a):
...      x_dim = a.shape[0]
...      outer_products = torch.outer(a[0], a[0])
...      for j in range(1, x_dim):
...          outer_products += torch.outer(a[j], a[j])
...      return outer_products / x_dim
...
>>> a = torch.randn (2000, 1000)
>>> outA = avg_matrix_outer_products_v2(a)
>>> outB = a.T @ a / a.shape[0]
>>> outB.allclose (outA, atol = 1.e-7)
True
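The reason a single matmul suffices: row n of a contributes a[n, i] * a[n, j] to entry (i, j) of the summed outer products, and summing that over n is exactly entry (i, j) of a.T @ a. Below is a minimal sketch of that check, assuming float64 so the comparison is not muddied by float32 round-off. The function names are hypothetical, and the einsum variant is an equivalent way of writing the same contraction that is not part of the original post.

```python
import torch

def avg_outer_products(a: torch.Tensor) -> torch.Tensor:
    # One matmul: (a.T @ a)[i, j] = sum over rows n of a[n, i] * a[n, j].
    return a.T @ a / a.shape[0]

def avg_outer_products_einsum(a: torch.Tensor) -> torch.Tensor:
    # Same contraction written explicitly: sum over the shared row index n.
    return torch.einsum("ni,nj->ij", a, a) / a.shape[0]

def avg_outer_products_loop(a: torch.Tensor) -> torch.Tensor:
    # Reference version: accumulate one outer product per row (slow).
    out = torch.zeros(a.shape[1], a.shape[1], dtype=a.dtype)
    for row in a:  # iterating over a 2-D tensor yields its rows
        out += torch.outer(row, row)
    return out / a.shape[0]

torch.manual_seed(0)
a = torch.randn(50, 8, dtype=torch.float64)
ref = avg_outer_products_loop(a)
print(torch.allclose(avg_outer_products(a), ref))
print(torch.allclose(avg_outer_products_einsum(a), ref))
```

Both comparisons should print True. The matmul version hands the whole reduction to one optimized BLAS call instead of launching x_dim separate outer-product and add operations from Python.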

(And, yes, I gave up on your _v1 version because it was annoyingly
slow and memory-inefficient.)

Best.

K. Frank
