Hello, I’ve been trying to implement the Mahalanobis distance between multiple nodes. My input has size (batch_size, time_stamps, num_nodes, embedding_size), and I would like an output of size (batch_size, time_stamps, num_nodes, num_nodes). The computation is quite simple: for each pair of nodes (x_i, x_j) that share the same batch index and time stamp, I need to compute (x_i - x_j)^T * M * (x_j - x_i). I managed to implement it with loops, but that was too slow. I also implemented it using torch.repeat and some reshapes, but I’m running out of memory.

So far I have this:

```
import torch

# x_perm: (batch_size, ts, num_nodes, num_channels)
# M is assumed to be (num_channels, num_channels)
# tiled[b, t, i, j] = x_j and tiled.transpose(3, 2)[b, t, i, j] = x_i
tiled = x_perm.repeat([1, 1, num_nodes, 1]).reshape(
    batch_size, ts, num_nodes, num_nodes, num_channels)
sub1 = tiled.transpose(3, 2) - tiled  # x_i - x_j
sub2 = tiled - tiled.transpose(3, 2)  # x_j - x_i
sub1 = sub1.unsqueeze(4)  # (batch_size, ts, num_nodes, num_nodes, 1, num_channels)
sub2 = sub2.unsqueeze(4)
score = torch.matmul(sub1, M)
score = torch.matmul(score, sub2.transpose(5, 4))
score = score.reshape(batch_size, ts, num_nodes, num_nodes)
```

Any suggestions are more than welcome.
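
One possible direction, offered as a sketch rather than a definitive answer: the quadratic form can be expanded so that the (batch_size, time_stamps, num_nodes, num_nodes, embedding_size) difference tensor is never materialized. Writing d = x_i - x_j, the quantity d^T M d equals x_i^T M x_i + x_j^T M x_j - x_i^T M x_j - x_j^T M x_i, and each of those terms comes from one batched matrix product. The function name `pairwise_mahalanobis_score` is hypothetical, and M is assumed to be a square (embedding_size, embedding_size) matrix. The result keeps the sign convention from the post, (x_i - x_j)^T M (x_j - x_i), which is the negative of the usual squared form.

```python
import torch


def pairwise_mahalanobis_score(x, M):
    """Hypothetical helper.

    x: (batch_size, time_stamps, num_nodes, embedding_size)
    M: (embedding_size, embedding_size), assumed square
    Returns (batch_size, time_stamps, num_nodes, num_nodes) where entry
    [b, t, i, j] is (x_i - x_j)^T M (x_j - x_i), as written in the post.
    """
    xM = torch.matmul(x, M)                        # row i holds x_i^T M
    cross = torch.matmul(xM, x.transpose(-1, -2))  # cross[i, j] = x_i^T M x_j
    diag = (xM * x).sum(dim=-1)                    # diag[i] = x_i^T M x_i
    # quad[i, j] = (x_i - x_j)^T M (x_i - x_j); valid even if M is not symmetric
    quad = diag.unsqueeze(-1) + diag.unsqueeze(-2) - cross - cross.transpose(-1, -2)
    return -quad


# Small usage example with made-up sizes
x = torch.randn(2, 3, 4, 5, dtype=torch.float64)
M = torch.randn(5, 5, dtype=torch.float64)
score = pairwise_mahalanobis_score(x, M)
print(score.shape)
```

The largest intermediate here has num_nodes * num_nodes entries per (batch, time stamp) instead of num_nodes * num_nodes * embedding_size. One caveat: the expansion subtracts terms of similar magnitude, so in float32 the results may show small rounding differences compared with the direct difference-based computation, for example diagonal entries that are close to zero rather than exactly zero.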