# Custom loss - correlation by column

Hi,

I’ve been having trouble computing my desired loss: I have two input tensors <X, Y>, with shapes: (batch_size, num_rows, num_columns). The loss I want to compute is the correlation between the respective <X, Y> columns for each one of the samples.

So far I’ve been iterating through each sample:

```
import torch
import torch.nn as nn

def cov(x, y):
    # x: (num_rows, cols_x), y: (num_rows, cols_y); each column is a variable
    x_bar = x - x.mean(axis=0)
    y_bar = y - y.mean(axis=0)
    N = x_bar.shape[0]
    return (y_bar.T @ x_bar).T / (N - 1)

class Corr(nn.Module):
    def __init__(self, device, eps=1e-5):
        super(Corr, self).__init__()
        self.device = device
        self.eps = eps

    def forward(self, x, y, params):

        loss = 0
        for i in range(x.shape[0]):  # loop over the samples in the batch
            xi = x[i, :, :]
            yi = y[i, :, :]
            C_yy = cov(yi, yi)
            C_yx = cov(yi, xi)
            C_xx = cov(xi, xi)

            # Add eps * I so the covariance matrices are safely invertible
            C_yy = C_yy + \
                torch.eye(C_yy.shape[0], device=self.device) * self.eps
            C_xx = C_xx + \
                torch.eye(C_xx.shape[0], device=self.device) * self.eps

            M = torch.linalg.multi_dot([torch.inverse(C_yy),
                                        C_yx,
                                        torch.inverse(C_xx),
                                        C_yx.T])

            loss += torch.trace(M)

        return (loss/x.shape[0])  # average over the batch
```
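Written out (this is my reading of the loop above, not a formula from the original post), with $B$ = `batch_size`, $N$ = `num_rows`, and $\bar{X}_b$, $\bar{Y}_b$ the column-centered matrices of sample $b$, the loss is:

$$
\mathcal{L} = \frac{1}{B}\sum_{b=1}^{B}\operatorname{tr}\left[(C_{yy}^{(b)}+\epsilon I)^{-1}\,C_{yx}^{(b)}\,(C_{xx}^{(b)}+\epsilon I)^{-1}\,C_{yx}^{(b)\top}\right],
\qquad
C_{yx}^{(b)} = \frac{1}{N-1}\,\bar{Y}_b^{\top}\bar{X}_b
$$

With $\epsilon = 0$ the trace equals the sum of the squared canonical correlations between the columns of $X_b$ and $Y_b$.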

Is there a faster way to implement this?

Hi Rane!

It’s relatively straightforward to implement a loop-free batch version of your
`cov()` function. Furthermore, `torch.linalg.inv()` (which `torch.inverse()`
is an alias for) computes matrix inverses on a batch basis if you pass in a
tensor with more than two dimensions. Together, these should give you the
tools to implement a loop-free version of your `Corr` loss.
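One operation in the loop that does not batch is `torch.trace()`, which only accepts a 2-D tensor. A minimal sketch of a batched replacement, assuming the matrices sit in the last two dimensions (the name `batch_trace` is a hypothetical helper, not a PyTorch function):

```python
import torch

def batch_trace(m: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: sum the main diagonal of every matrix in a
    # (..., n, n) tensor. torch.trace() only accepts 2-D input, so take
    # the diagonal over the last two dims and sum it instead.
    return m.diagonal(dim1=-2, dim2=-1).sum(-1)

m = torch.randn(4, 3, 3)
per_matrix = torch.stack([torch.trace(mi) for mi in m])
print(torch.allclose(batch_trace(m), per_matrix))  # True
```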

Here is an illustration of a batch-`cov()` and batch-`inv()`:

```
>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> def cov(x, y):
...     x_bar = x - x.mean(axis=0)
...     y_bar = y - y.mean(axis=0)
...     N = x_bar.shape[0]
...     return (y_bar.T @ x_bar).T / (N - 1)
...
>>> def batchCov (x, y):
...     x_bar = x - x.mean (1).unsqueeze (1)
...     y_bar = y - y.mean (1).unsqueeze (1)
...     N = x_bar.shape[1]
...     return (y_bar.transpose (1, 2) @ x_bar).transpose (1, 2) / (N - 1)   # @ will do batch mm
...
>>> x = torch.randn (3, 5, 5)
>>> y = torch.randn (3, 5, 5)
>>>
>>> torch.allclose (batchCov (x, y), torch.stack ([cov (x[i], y[i]) for i in range (3)]))
True
>>>
>>> torch.allclose (torch.linalg.inv (x), torch.stack ([torch.linalg.inv (x[i]) for i in range (3)]))
True
```

Best.

K. Frank


Thank you very much @KFrank,

I’ve used your `batchCov` function and modified my `Corr` to make it loop-free.
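For readers landing here, a minimal sketch of what a loop-free `Corr` could look like. This is my reconstruction, not the code the poster actually used. Assumptions: inputs have shape `(batch_size, num_rows, num_columns)`, the device and dtype are taken from `x` instead of being passed to the constructor, and the unused `params` argument is dropped. `batch_cov` and `BatchCorr` are hypothetical names.

```python
import torch
import torch.nn as nn


def batch_cov(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x: (B, N, Cx), y: (B, N, Cy) -> (B, Cx, Cy).
    # Same result as the per-sample cov(): x_bar^T @ y_bar / (N - 1).
    x_bar = x - x.mean(dim=1, keepdim=True)
    y_bar = y - y.mean(dim=1, keepdim=True)
    n = x.shape[1]
    return x_bar.transpose(1, 2) @ y_bar / (n - 1)


class BatchCorr(nn.Module):
    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        c_yy = batch_cov(y, y)  # (B, Cy, Cy)
        c_yx = batch_cov(y, x)  # (B, Cy, Cx)
        c_xx = batch_cov(x, x)  # (B, Cx, Cx)

        # A single (n, n) identity broadcasts over the batch dimension.
        eye_y = torch.eye(c_yy.shape[-1], device=x.device, dtype=x.dtype)
        eye_x = torch.eye(c_xx.shape[-1], device=x.device, dtype=x.dtype)
        c_yy = c_yy + self.eps * eye_y
        c_xx = c_xx + self.eps * eye_x

        # torch.linalg.inv and @ both operate on the whole batch at once.
        m = (torch.linalg.inv(c_yy) @ c_yx
             @ torch.linalg.inv(c_xx) @ c_yx.transpose(1, 2))

        # Batched trace (torch.trace is 2-D only), then average over the batch.
        return m.diagonal(dim1=-2, dim2=-1).sum(-1).mean()


torch.manual_seed(0)
x = torch.randn(4, 50, 3, dtype=torch.float64)
y = torch.randn(4, 50, 2, dtype=torch.float64)
loss = BatchCorr()(x, y)
print(loss.shape)  # torch.Size([]), a scalar
```

The sketch keeps explicit inverses to mirror the original loop. The PyTorch documentation for `torch.linalg.inv` recommends `torch.linalg.solve` when the inverse is only used to multiply another matrix, since it is faster and more numerically stable, so swapping it in is a possible refinement.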