Efficient Distance Matrix Computation

Hello. I am working on a project that also needs to compute this distance. First, thank you for sharing such a detailed approach.
I started with approach A and ran out of memory; judging from the PyTorch documentation, there does not seem to be a good built-in solution to date.
It had not occurred to me to compute (a-b)^2 as a^2 + b^2 - 2ab. I think this is the best option for both speed and memory when using only ready-made PyTorch APIs.
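
For anyone else reading, here is a minimal sketch of how I understand the expansion trick, assuming a and b are 2-D float tensors of shape (n, d) and (m, d). The function name pairwise_sq_dists is my own and does not come from the original post, and the final clamp is my addition to hide small negative values caused by rounding.

```python
import torch


def pairwise_sq_dists(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distances between the rows of a (n, d) and b (m, d).

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the largest
    intermediate is the (n, m) result rather than an (n, m, d) tensor.
    """
    a_sq = (a * a).sum(dim=1, keepdim=True)   # (n, 1) squared row norms of a
    b_sq = (b * b).sum(dim=1).unsqueeze(0)    # (1, m) squared row norms of b
    cross = a @ b.t()                         # (n, m) dot products
    d2 = a_sq + b_sq - 2.0 * cross            # broadcasts to (n, m)
    # Rounding can push near-zero distances slightly below zero (assumed fix).
    return d2.clamp(min=0.0)
```

Take the square root of the result if the plain Euclidean distance is needed rather than the squared one.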
I think the best way to deal with numerical stability may be to write a cffi extension and use CUDA to compute the distance directly.
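
Short of writing a custom kernel, one stopgap I can imagine is to compute the differences directly but only for a block of rows at a time. This is only a sketch under my own assumptions: the name chunked_sq_dists and the chunk_size parameter are hypothetical, and it trades speed for accuracy because it avoids subtracting two large, nearly equal numbers.

```python
import torch


def chunked_sq_dists(a: torch.Tensor, b: torch.Tensor,
                     chunk_size: int = 1024) -> torch.Tensor:
    """Squared Euclidean distances computed directly from differences.

    Processes chunk_size rows of a at a time, so the temporary tensor is
    (chunk_size, m, d) instead of (n, m, d).
    """
    n, m = a.shape[0], b.shape[0]
    out = torch.empty(n, m, dtype=a.dtype, device=a.device)
    for start in range(0, n, chunk_size):
        stop = start + chunk_size
        # (chunk, 1, d) - (1, m, d) broadcasts to (chunk, m, d)
        diff = a[start:stop, None, :] - b[None, :, :]
        out[start:stop] = (diff * diff).sum(dim=-1)
    return out
```

A smaller chunk_size lowers peak memory at the cost of more loop iterations; the right value depends on d, m, and the available GPU memory.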