Mahalanobis distance between all nodes

Hello, I’ve been trying to implement the Mahalanobis distance between multiple nodes. My input has size (batch_size, time_stamps, num_nodes, embedding_size), and I would like an output of size (batch_size, time_stamps, num_nodes, num_nodes). The computation is quite simple: for each pair of nodes (x_i, x_j) that share the same batch index and time stamp, I need to compute (x_i - x_j)^T * M * (x_j - x_i). I managed to implement it with loops, but that was too slow. I also implemented it using torch.repeat and some reshapes, but I’m running out of memory.

So far I have this:

            sub1 = x_perm.repeat([1, 1, num_nodes, 1]).reshape(batch_size, ts, num_nodes, num_nodes,
                                                               num_channels).transpose(
                3, 2) - \
                   x_perm.repeat([1, 1, num_nodes, 1]).reshape(batch_size, ts, num_nodes,
                                                               num_nodes, num_channels)
            sub2 = x_perm.repeat([1, 1, num_nodes, 1]).reshape(batch_size, ts, num_nodes,
                                                               num_nodes, num_channels) - \
                   x_perm.repeat([1, 1, num_nodes, 1]).reshape(batch_size, ts, num_nodes,
                                                               num_nodes, num_channels).transpose(
                       3, 2)

            sub1 = sub1.unsqueeze(4)
            sub2 = sub2.unsqueeze(4)

            score = torch.matmul(sub1, M)
            score = torch.matmul(score, sub2.transpose(5, 4))
            score = score.reshape(batch_size, ts, num_nodes, num_nodes)

Any suggestions are more than welcome.

I’m wondering if you could take advantage of the batched Mahalanobis distance that is used internally by the multivariate normal distribution.
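
One possible reading of this idea, as a rough sketch: if M is symmetric positive definite (an assumption; the question does not say so), it factors as M = L L^T, and (x_i - x_j)^T M (x_i - x_j) equals the squared Euclidean distance between the transformed rows x_i L and x_j L. The code below does not call the distribution's internal helper. It swaps in torch.linalg.cholesky and torch.cdist instead, and the function name is hypothetical. It returns the standard quadratic form, which is the negative of the expression with (x_j - x_i) on the right that appears in the question.

```python
import torch

def pairwise_mahalanobis_sq_chol(x, M):
    """Squared Mahalanobis distance between all node pairs.

    x: (batch, time, nodes, channels)
    M: (channels, channels), ASSUMED symmetric positive definite
    returns: (batch, time, nodes, nodes)
    """
    b, t, n, c = x.shape
    # Cholesky factor: lower-triangular L with M = L @ L.T
    L = torch.linalg.cholesky(M)
    # d M d^T = (d L)(d L)^T, so transform every node embedding once
    y = (x @ L).reshape(b * t, n, c)
    # Pairwise Euclidean distances per (batch, time) slice: (b*t, n, n)
    dist = torch.cdist(y, y)
    return dist.pow(2).reshape(b, t, n, n)
```

Because only the (batch, time, nodes, nodes) result is materialized, this should avoid the five-dimensional intermediate that the repeat-based version builds.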

Without completely understanding the context of your code, I have a sneaky feeling that those repeats should be removed.
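
A minimal sketch of what dropping the repeats could look like, assuming x has shape (batch, time, nodes, channels) and M has shape (channels, channels). Both function names are hypothetical. The first version lets broadcasting form the pairwise differences, which is shorter but still allocates a (batch, time, nodes, nodes, channels) tensor. The second expands the quadratic form into x_i^T M x_i - x_i^T M x_j - x_j^T M x_i + x_j^T M x_j, so the largest intermediate is only (batch, time, nodes, nodes). Both compute (x_i - x_j)^T M (x_i - x_j); the expression in the question, with (x_j - x_i) on the right, is the negative of this.

```python
import torch

def pairwise_quadform_broadcast(x, M):
    """Broadcasted version: simple, but builds a 5-D difference tensor."""
    # diff[..., i, j, :] = x_i - x_j, shape (b, t, n, n, c)
    diff = x.unsqueeze(3) - x.unsqueeze(2)
    # Row-wise quadratic form d M d^T, summed over the channel axis
    return ((diff @ M) * diff).sum(dim=-1)

def pairwise_quadform_expanded(x, M):
    """Expanded version: never materializes the pairwise differences."""
    xM = x @ M                                   # (b, t, n, c)
    self_term = (xM * x).sum(dim=-1)             # (b, t, n): x_i^T M x_i
    cross = xM @ x.transpose(-1, -2)             # (b, t, n, n): x_i^T M x_j
    return (self_term.unsqueeze(-1)              # x_i^T M x_i along rows
            + self_term.unsqueeze(-2)            # x_j^T M x_j along columns
            - cross                              # x_i^T M x_j
            - cross.transpose(-1, -2))           # x_j^T M x_i
```

To match the sign convention in the question, negate the returned tensor. The expanded form can lose some precision to cancellation when embeddings are large and distances are small, so it may be worth checking it against the broadcast version on a small input first.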