Compute batched L2 distance between two tensors - torch.cdist()

I have two 2-D tensors ‘a’ and ‘b’ with shapes a = (batch_size, embed_dim) and b = (num_neurons, embed_dim), and I want to compute the L2 distance between every row of ‘a’ and every row of ‘b’, so that the resulting distance matrix has shape (batch_size, num_neurons).
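Put differently, the target is d[i, j] = ||a[i] - b[j]||_2. A minimal sketch of that definition computed directly by broadcasting (subtract first, then take the norm), assuming small illustrative sizes; `pairwise_l2_broadcast` is a hypothetical helper, not a PyTorch API. It materializes a (batch_size, num_neurons, embed_dim) intermediate, so it is only meant as a readable reference, not as the fast path.

```python
import torch

def pairwise_l2_broadcast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper for all-pairs L2 distance.

    a: (batch_size, embed_dim), b: (num_neurons, embed_dim)
    returns: (batch_size, num_neurons)
    """
    # (batch_size, 1, embed_dim) - (1, num_neurons, embed_dim)
    # broadcasts to (batch_size, num_neurons, embed_dim)
    diff = a.unsqueeze(1) - b.unsqueeze(0)
    # norm over embed_dim gives one distance per (row of a, row of b) pair
    return diff.norm(dim=-1)

# Small illustrative sizes (assumed) to keep the intermediate tensor cheap
batch_size, num_neurons, embed_dim = 32, 20, 128
a = torch.rand(batch_size, embed_dim)
b = torch.rand(num_neurons, embed_dim)
d = pairwise_l2_broadcast(a, b)
print(d.shape)  # torch.Size([32, 20])
```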

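import numpy as np
import torch
import torch.nn as nn

nr_units = 200     # num_neurons
z_dim = 128        # embed_dim
batch_size = 1024

# 'b' holds the neuron weights; b.weight has shape (nr_units, z_dim)
b = nn.Embedding(num_embeddings = nr_units, embedding_dim = z_dim)

# Initialize weights of 'b' uniformly in [-1/sqrt(z_dim), 1/sqrt(z_dim)]
with torch.no_grad():
    b.weight.uniform_(-np.sqrt(1 / z_dim), np.sqrt(1 / z_dim))

# 'a' is a batch of inputs with shape (batch_size, z_dim)
a = torch.rand(batch_size, z_dim)

# Compute L2-distance using torch.cdist(); x2 must be the weight tensor, not the module.
# Result shape: (batch_size, nr_units) = (1024, 200)
dist_cdist = torch.cdist(
    x1 = a, x2 = b.weight,
    p = 2, compute_mode = 'donot_use_mm_for_euclid_dist'
)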

While searching for similar questions, I came across this thread which has this comment:

This is an expected behavior, given that mm distance is calculated as a^2 - 2ab + b^2. If a and b are close, then catastrophic cancellation happens and result is imprecise (can even be negative, but we clamp it to be non-negative).
Using matrix multiplication dramatically improves performance, but if accuracy is not enough for your application, then you should not use it.
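The effect described in that comment can be reproduced without calling cdist at all. A minimal sketch, assuming two nearly identical float32 vectors: it evaluates the expanded form ||x||^2 - 2 x.y + ||y||^2 by hand and compares it with the direct ||x - y|| and a float64 reference. The +10 offset and the 1e-4 noise scale are arbitrary choices to make the effect visible, and `expanded_l2` / `direct_l2` are hypothetical helpers that mimic, but are not, cdist's internal kernels. Because each squared norm is around 1.4e4, float32 can only resolve it to about 1e-3, while the true squared distance is around 1e-6, so the expanded form returns either 0 or a value far too large.

```python
import torch

torch.manual_seed(0)
embed_dim = 128

# Two nearly identical float32 vectors. The +10 offset (an arbitrary choice)
# makes the squared norms large (~1.4e4) so the cancellation is easy to see.
x = torch.rand(1, embed_dim) + 10.0
y = x + 1e-4 * torch.randn(1, embed_dim)

def expanded_l2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # ||x||^2 - 2 x.y + ||y||^2: the matmul-friendly form from the comment,
    # clamped to be non-negative before the square root
    sq = (x * x).sum(-1, keepdim=True) - 2 * (x @ y.T) + (y * y).sum(-1)
    return sq.clamp_min(0).sqrt()

def direct_l2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # ||x - y||: subtract first, then take the norm
    return (x.unsqueeze(1) - y.unsqueeze(0)).norm(dim=-1)

ref = direct_l2(x.double(), y.double()).item()  # float64 reference
direct = direct_l2(x, y).item()
expanded = expanded_l2(x, y).item()
print(f"reference (float64): {ref:.6e}")
print(f"direct    (float32): {direct:.6e}")
print(f"expanded  (float32): {expanded:.6e}")
```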

In my use case, ‘a’ and ‘b’ are expected to be very close to each other, which according to this comment is exactly when catastrophic cancellation can occur. So I am using:

compute_mode = 'donot_use_mm_for_euclid_dist'

in torch.cdist() computation. Is this the correct approach to handle the case when a and b are close?
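One way to check this empirically is to compare the two explicit compute modes against a float64 reference on data where many pairs are close. torch.cdist documents three values for compute_mode: 'use_mm_for_euclid_dist_if_necessary' (the default, which uses the matmul approach for p = 2 when P > 25 or R > 25), 'use_mm_for_euclid_dist' and 'donot_use_mm_for_euclid_dist'. With P = 1024 and R = 200 as in the question, the default would take the matmul path. The sketch below assumes synthetic data (each row of ‘a’ is a small perturbation of some row of ‘b’) and only measures the error on those near-duplicate pairs; the sizes and noise scale are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
batch_size, num_neurons, embed_dim = 64, 20, 128

# Hypothetical "close" data: each row of a is a tiny perturbation of a row of b
b = torch.rand(num_neurons, embed_dim)
idx = torch.randint(num_neurons, (batch_size,))
a = b[idx] + 1e-4 * torch.randn(batch_size, embed_dim)

# float64 reference computed with the direct (non-matmul) path
ref = torch.cdist(a.double(), b.double(), compute_mode="donot_use_mm_for_euclid_dist")
rows = torch.arange(batch_size)
near_ref = ref[rows, idx]  # distances of the near-duplicate pairs

errors = {}
for mode in ("use_mm_for_euclid_dist", "donot_use_mm_for_euclid_dist"):
    d = torch.cdist(a, b, p=2.0, compute_mode=mode)
    near = d[rows, idx].double()
    # worst-case relative error over the close pairs, where cancellation matters most
    errors[mode] = ((near - near_ref).abs() / near_ref).max().item()
    print(f"{mode:30s} max relative error on close pairs: {errors[mode]:.2e}")
```

The trade-off is speed: the direct path avoids the cancellation but does not benefit from a single large matrix multiplication, so it is worth timing on realistic sizes.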

I found another thread that discusses pairwise distance computation, but it seems to be outdated.

So what is the current, correct way to compute batched pairwise L2 distances between ‘a’ and ‘b’, even when they are very close to each other?
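For reference, torch.cdist itself handles both the 2-D case in the question and inputs with a leading batch dimension ((B, P, M) and (B, R, M) give (B, P, R)). A minimal sketch, assuming the direct path is acceptable in terms of speed; `batched_l2` is a hypothetical wrapper that simply pins compute_mode so the matmul path is never taken, and the tensor sizes are illustrative.

```python
import torch

def batched_l2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: all-pairs L2, (..., P, M) x (..., R, M) -> (..., P, R).

    Forces the direct (non-matmul) path, which avoids the
    ||a||^2 - 2 a.b + ||b||^2 cancellation at the cost of speed.
    """
    return torch.cdist(a, b, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")

# 2-D case from the question: (batch_size, embed_dim) x (num_neurons, embed_dim)
a = torch.rand(32, 128)
b = torch.rand(20, 128)
print(batched_l2(a, b).shape)    # torch.Size([32, 20])

# 3-D case: an extra leading dimension, e.g. one set of neurons per group
a3 = torch.rand(4, 32, 128)
b3 = torch.rand(4, 20, 128)
print(batched_l2(a3, b3).shape)  # torch.Size([4, 32, 20])
```

If the direct path turns out to be too slow on large inputs, one option worth benchmarking (an assumption, not something confirmed in this thread) is to keep the faster matmul path but cast the inputs to float64, which shrinks the cancellation error considerably at the cost of double-precision arithmetic.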