Hi Aesteban!
The calculation you are asking for can be performed with a single
matrix multiplication:
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> def avg_matrix_outer_products_v2(a):
...     x_dim = a.shape[0]
...     outer_products = torch.outer(a[0], a[0])
...     for j in range(1, x_dim):
...         outer_products += torch.outer(a[j], a[j])
...     return outer_products / x_dim
...
>>> a = torch.randn (2000, 1000)
>>> outA = avg_matrix_outer_products_v2(a)
>>> outB = a.T @ a / a.shape[0]
>>> outB.allclose (outA, atol = 1.e-7)
True
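Why a single matmul is enough: write $A$ for `a` (shape $n \times d$, with $n$ = `a.shape[0]`) and $a_j$ for its $j$-th row, treated as a column vector. Entry $(i, k)$ of `torch.outer(a[j], a[j])` is $A_{ji} A_{jk}$, so summing over $j$ is exactly the contraction that $A^{\mathsf T} A$ performs:

```latex
\frac{1}{n} \sum_{j=1}^{n} \left( a_j a_j^{\mathsf T} \right)_{ik}
  = \frac{1}{n} \sum_{j=1}^{n} A_{ji} A_{jk}
  = \frac{1}{n} \sum_{j=1}^{n} (A^{\mathsf T})_{ij} A_{jk}
  = \frac{1}{n} \left( A^{\mathsf T} A \right)_{ik}
```

That is, the average of the outer products is `a.T @ a / a.shape[0]`, and the loop over rows becomes the inner sum of one matrix multiplication.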
(And, yes, I gave up on your _v1
version because it was annoyingly
slow and memory inefficient.)
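If you want to check the identity yourself without waiting on a large loop, here is a small sketch. It uses `torch.einsum` as an alternative spelling of the same contraction (not something from your original code), and compares it and the matmul against an explicit loop on a small double-precision tensor. The helper names `avg_outer_products_einsum` and `avg_outer_products_loop` are just illustrative.

```python
import torch

def avg_outer_products_einsum(a: torch.Tensor) -> torch.Tensor:
    # Illustrative helper: "ji,jk->ik" multiplies a[j, i] * a[j, k] and
    # sums over the row index j, i.e. the sum of the row outer products.
    return torch.einsum("ji,jk->ik", a, a) / a.shape[0]

def avg_outer_products_loop(a: torch.Tensor) -> torch.Tensor:
    # Reference version: explicit Python loop over rows (slow, for checking only).
    acc = torch.zeros(a.shape[1], a.shape[1], dtype=a.dtype)
    for row in a:
        acc += torch.outer(row, row)
    return acc / a.shape[0]

torch.manual_seed(0)
# Small sizes and float64 so the comparison is not limited by float32 round-off.
a = torch.randn(50, 8, dtype=torch.float64)

loop_result = avg_outer_products_loop(a)
einsum_result = avg_outer_products_einsum(a)
matmul_result = a.T @ a / a.shape[0]

print(torch.allclose(einsum_result, loop_result))  # True
print(torch.allclose(matmul_result, loop_result))  # True
```

In practice `a.T @ a` is the one to use; `einsum` is only handy when you want the index bookkeeping spelled out.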
Best.
K. Frank