Hi Rane!
It’s relatively straightforward to implement a loop-free batch version of your
cov() function. Furthermore, torch.linalg.inv() (which torch.inverse() is an
alias for) inverts a whole batch of matrices at once if you pass in a tensor
with more than two dimensions: the last two dimensions hold the square
matrices, and any leading dimensions are treated as batch dimensions. Together,
these should give you the tools to implement a loop-free version of your Corr
loss.
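As a quick check of that batching behavior, here is a minimal sketch (the shape, the seed, and the variable names are arbitrary choices for illustration) that inverts a 2 x 3 grid of 4 x 4 matrices in one call and compares the result against an explicit Python loop and against the torch.inverse() alias. Adding 4 * I is only an assumption to keep the random matrices comfortably invertible for the demo; it is not something torch.linalg.inv() requires.

```python
import torch

torch.manual_seed(0)

# Arbitrary shape for illustration: a 2 x 3 grid of 4 x 4 matrices.
# Adding 4 * I (broadcast over the leading dims) keeps them well conditioned.
a = torch.randn(2, 3, 4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)

batched = torch.linalg.inv(a)   # one call, no Python loop
alias = torch.inverse(a)        # torch.inverse is an alias for torch.linalg.inv

# Reference: invert each 4 x 4 matrix individually.
looped = torch.stack([
    torch.stack([torch.linalg.inv(a[i, j]) for j in range(a.shape[1])])
    for i in range(a.shape[0])
])

# a @ batched is also batched; every 4 x 4 product should be the identity.
identity = torch.eye(4, dtype=torch.float64).expand_as(a)
residual = (a @ batched - identity).abs().max().item()
```

The same pattern works for any number of leading batch dimensions, since only the last two dimensions are interpreted as the matrix.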
Here is an illustration of a batched cov() and a batched inv():
```python
>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> def cov(x, y):
... x_bar = x - x.mean(axis=0)
... y_bar = y - y.mean(axis=0)
... N = x_bar.shape[0]
... return (y_bar.T @ x_bar).T / (N - 1)
...
>>> def batchCov (x, y):
... x_bar = x - x.mean (1).unsqueeze (1)
... y_bar = y - y.mean (1).unsqueeze (1)
... N = x_bar.shape[1]
... return (y_bar.transpose (1, 2) @ x_bar).transpose (1, 2) / (N - 1) # @ will do batch mm
...
>>> x = torch.randn (3, 5, 5)
>>> y = torch.randn (3, 5, 5)
>>>
>>> torch.allclose (batchCov (x, y)[2], cov (x[2], y[2]))
True
>>>
>>> torch.allclose (torch.linalg.inv (x)[2], torch.linalg.inv (x[2]))
True
```
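For a fuller check, here is a self-contained sketch that restates batchCov() with keepdim=True in place of .unsqueeze(1) (the names batch_cov, cxy, cxx, and so on are just illustrative). It compares every batch element, not only index 2, against the looped cov() and against torch.cov(), and then inverts the batch of covariance matrices in a single torch.linalg.inv() call. The choice of float64 and 50 samples per batch element is an assumption made so that batch_cov(x, x) is invertible: with only 5 samples of 5 features, the centered data has rank at most 4, so the 5 x 5 covariance would be singular.

```python
import torch

def batch_cov(x, y):
    # x: (B, N, Dx), y: (B, N, Dy): B independent sets of N samples.
    # Center each batch element along its sample dimension (dim 1);
    # keepdim=True keeps the means broadcastable without .unsqueeze(1).
    x_bar = x - x.mean(dim=1, keepdim=True)
    y_bar = y - y.mean(dim=1, keepdim=True)
    n = x.shape[1]
    # (B, Dx, N) @ (B, N, Dy) -> (B, Dx, Dy); @ batches over dim 0.
    # This equals (y_bar^T x_bar)^T as computed in batchCov() above.
    return x_bar.transpose(1, 2) @ y_bar / (n - 1)

def cov(x, y):
    # Single-matrix reference, equivalent to cov() above; x, y are (N, D).
    x_bar = x - x.mean(dim=0)
    y_bar = y - y.mean(dim=0)
    return x_bar.T @ y_bar / (x.shape[0] - 1)

torch.manual_seed(2022)
# 50 samples of 5 features per batch element, so batch_cov(x, x) is full rank.
x = torch.randn(3, 50, 5, dtype=torch.float64)
y = torch.randn(3, 50, 5, dtype=torch.float64)

cxy = batch_cov(x, y)
looped = torch.stack([cov(x[i], y[i]) for i in range(x.shape[0])])

# torch.cov expects variables in rows and observations in columns, hence .T;
# its default correction=1 matches the (N - 1) divisor used here.
cxx = batch_cov(x, x)
reference = torch.stack([torch.cov(x[i].T) for i in range(x.shape[0])])

# Invert all three 5 x 5 covariance matrices in one call.
cxx_inv = torch.linalg.inv(cxx)
eye = torch.eye(5, dtype=torch.float64)
residual = (cxx @ cxx_inv - eye).abs().max().item()
```

Note that torch.cov() itself only handles a single 2-D input, which is why the reference above still loops; batch_cov() is the part that removes the loop.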
Best.
K. Frank