Hi,
I’ve been having trouble computing my desired loss: I have two input tensors <X, Y>, with shapes: (batch_size, num_rows, num_columns). The loss I want to compute is the correlation between the respective <X, Y> columns for each one of the samples.
So far I’ve been iterating through each sample:
def cov(x, y):
x_bar = x - x.mean(axis=0)
y_bar = y - y.mean(axis=0)
N = x_bar.shape[0]
return (y_bar.T @ x_bar).T / (N - 1)
class Corr(nn.Module):
def __init__(self, device, eps=1e-5):
super(Corr, self).__init__()
self.device = device
self.eps = eps
def forward(self, x, y, params):
loss = 0
for i in range(x.shape[0]):
xi = x[i, :, :]
yi = y[i, :, :]
C_yy = cov(yi, yi)
C_yx = cov(yi, xi)
C_xx = cov(xi, xi)
C_yy = C_yy + \
torch.eye(C_yy.shape[0], device=self.device) * self.eps
C_xx = C_xx + \
torch.eye(C_xx.shape[0], device=self.device) * self.eps
M = torch.linalg.multi_dot([torch.inverse(C_yy),
C_yx,
torch.inverse(C_xx),
C_yx.T])
loss += torch.trace(M)
return (loss/x.shape[0])
Is there a faster way to implement this?