Pairwise distance calc for large tensors

Hello,

Does anyone know how to calculate pairwise distances between points in a memory-efficient way? My batches are large, approximately 150000 points per batch. I realise that torch.cdist creates an enormous N x N output tensor, so memory grows as N^2. I have looked around for a solution but haven't been able to find one.

I would be very grateful if anyone could share a code snippet for this. I found KeOps, but unfortunately it's not supported on Windows.

I know I could also use FAISS k-NN search, but that limits each query to its k nearest neighbours, whereas I would like every point to attend to all points inside the batch.
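One way to let every point attend to all points without ever holding the full N x N matrix is to process the query points in chunks. This is a minimal sketch under hypothetical assumptions: the attention weights are softmax(-distance / temperature), and gradients are needed for the features but not for the coordinates. The names chunked_distance_attention, chunk_size and temperature are illustrative, not PyTorch APIs. torch.utils.checkpoint is used so that, during training, each (chunk, N) weight block is recomputed in the backward pass instead of being kept alive by the autograd graph.

```python
import torch
from torch.utils.checkpoint import checkpoint


def _attend_block(q_pos, pos, feats, temperature):
    # (chunk, N) distances from this block of query points to every point.
    d = torch.cdist(q_pos, pos)
    # Hypothetical weighting: closer points get larger weights; each row sums to 1 over all N points.
    w = torch.softmax(-d / temperature, dim=1)
    return w @ feats  # (chunk, F) aggregated features


def chunked_distance_attention(pos, feats, chunk_size=1024, temperature=1.0):
    """Every point attends to all N points; only (chunk_size, N) blocks are ever built.

    pos: (N, D) coordinates (assumed not to need gradients), feats: (N, F) features.
    Returns an (N, F) tensor.
    """
    outs = []
    for start in range(0, pos.shape[0], chunk_size):
        q_pos = pos[start:start + chunk_size]
        if torch.is_grad_enabled():
            # Training: recompute this block during backward instead of storing its weights.
            outs.append(checkpoint(_attend_block, q_pos, pos, feats, temperature,
                                   use_reentrant=False))
        else:
            outs.append(_attend_block(q_pos, pos, feats, temperature))
    return torch.cat(outs, dim=0)
```

With chunk_size=1024 and N = 150000, each (chunk, N) temporary holds 1024 x 150000 four-byte floats, about 0.6 GB, and a few such temporaries exist at once inside a block. The price of checkpointing is running the forward computation of every block twice during training.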

I am still new and learning this concept, so please bear with me. =) Thanks!!

Hi Abraham!

Could you explain your use case a little more?

Under the simplest interpretation of what you seem to be doing, you will need a lot
of memory. If you need to store all pairwise distances of 150000 points in memory
at the same time, that takes 150000 x 150000 x 4 bytes, about 90 GB (assuming four-byte floats).
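The arithmetic behind that figure, together with how many rows of the distance matrix would fit in a given memory budget if they were computed a block at a time, can be sketched as follows. The helper names and the 2 GiB budget are hypothetical.

```python
def pairwise_bytes(n_rows: int, n_cols: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to hold an (n_rows, n_cols) distance matrix (four-byte floats by default)."""
    return n_rows * n_cols * bytes_per_elem


def max_chunk_rows(n: int, budget_bytes: int, bytes_per_elem: int = 4) -> int:
    """Most query rows whose (rows, n) distance block fits in budget_bytes (at least 1)."""
    return max(1, budget_bytes // (n * bytes_per_elem))


n = 150_000
print(pairwise_bytes(n, n) / 1e9)       # 90.0  (decimal GB, about 83.8 GiB)
print(max_chunk_rows(n, 2 * 1024**3))   # 3579  rows per block for a 2 GiB budget
```

In practice the budget should leave headroom, since operations applied to a block (such as a softmax) allocate temporaries of similar size.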

Normally by “batch” we mean a batch of independent samples that don’t interact in
a substantive way with one another. They are batched together so that they can be
passed through the network at the same time to make more efficient use of the GPU
(or CPU).
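For illustration, torch.cdist accepts batched input of shape (B, P, D) and then computes distances only between points of the same sample, giving a (B, P, P) result. The sizes below are hypothetical (8 samples of 500 points); the point is that treating all B * P points as one set needs B times as many distance entries.

```python
import torch

# Hypothetical sizes: B independent samples, each with P points in D dimensions.
B, P, D = 8, 500, 3
x = torch.randn(B, P, D)

# Batched cdist: distances are computed only between points of the same sample.
within = torch.cdist(x, x)
print(within.shape)  # torch.Size([8, 500, 500])

# Treating all B * P points as one set would instead need a (B*P, B*P) matrix,
# i.e. B times as many entries.
print(within.numel(), (B * P) ** 2)  # 2000000 16000000
```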

If I understand you correctly, you want to compute the distance between a point in one
sample and a point in another sample, but then those samples would no longer be treated
independently. Do you really need to compute all pairwise distances? Do you really
need to store them all in memory at the same time? If so, please provide a little more
detail about your use case, and we’ll see if we can offer any practical (if complicated)
suggestions for how to reduce your peak memory usage.
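As an illustration of not storing all distances at once: if what is ultimately needed per point is a reduction over all other points (hypothetically, the distance to its nearest neighbour), the distance matrix can be computed one block of rows at a time and each block reduced immediately. This is a sketch for inference only (it runs under torch.no_grad()), and chunked_nearest_dist is a made-up name.

```python
import torch


@torch.no_grad()  # inference only: no autograd graph keeps earlier blocks alive
def chunked_nearest_dist(x: torch.Tensor, chunk_size: int = 1024) -> torch.Tensor:
    """Distance from each row of x (N, D) to its nearest *other* row.

    Only a (chunk_size, N) block of the distance matrix exists at a time,
    so peak extra memory is O(chunk_size * N) rather than O(N^2).
    """
    n = x.shape[0]
    mins = []
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        d = torch.cdist(x[start:stop], x)        # (stop - start, N) block
        rows = torch.arange(stop - start, device=x.device)
        d[rows, rows + start] = float("inf")      # ignore each point's distance to itself
        mins.append(d.min(dim=1).values)          # reduce the block right away
        del d                                     # free the block before the next is allocated
    return torch.cat(mins)
```

For example, chunked_nearest_dist(points, chunk_size=2048) on a (150000, 3) tensor holds blocks of 2048 x 150000 four-byte floats, about 1.2 GB each, instead of the full 90 GB matrix.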

Best.

K. Frank