Sparse calculation of `cdist` for some specific pairs

import torch
h, w = 8, 8
c = 2
a = torch.randn(h*w, c)

S_mat = torch.cdist(a, a)  # full (h*w, h*w) matrix of pairwise distances

i_ind = torch.tensor([12, 14, 51]).long()
j_ind = torch.tensor([21, 60, 33]).long()

result = S_mat[i_ind, j_ind]  # distances for the pairs (12, 21), (14, 60), (51, 33)

In fact, the tensor `a` is very large (say h, w = 100, 100, i.e. 10,000 points), so I do not want to compute the distances between all pairs of points, only between a few selected index pairs. Is there any way to do this?
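A minimal sketch of the idea, assuming Euclidean (p=2) distance as in `torch.cdist`'s default: gather only the rows you need with `a[i_ind]` and `a[j_ind]`, then take the norm of their difference. One distance is computed per pair, so memory grows with the number of pairs instead of with (h*w)^2. For h, w = 100, 100 the full matrix would have 10,000 × 10,000 = 10^8 entries, roughly 400 MB in float32. The index values below are just the ones from the question.

```python
import torch

torch.manual_seed(0)

h, w = 100, 100
c = 2
a = torch.randn(h * w, c)  # 10,000 points with c features each

# Selected pairs (i_ind[k], j_ind[k]); example values taken from the question.
i_ind = torch.tensor([12, 14, 51])
j_ind = torch.tensor([21, 60, 33])

# Gather only the rows we need: both have shape (3, c).
diff = a[i_ind] - a[j_ind]

# Euclidean distance per pair, shape (3,).
result = torch.linalg.vector_norm(diff, dim=-1)

# Sanity check: cdist between just the 6 gathered rows gives a 3x3 matrix
# whose diagonal holds the same pairs, without ever building the 10^8 matrix.
reference = torch.cdist(a[i_ind], a[j_ind]).diagonal()
print(result)
print(torch.allclose(result, reference, atol=1e-5))
```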


Is there a pattern to the way you want to perform this? Or is it completely random?

`a` is random, and so are `i_ind` and `j_ind`.

import torch
import torch.nn.functional as F

ann_num = 3
h, w = 8, 8
c = 2
a = torch.randn(h*w, c)

S_mat = torch.cdist(a, a)  # full matrix, built here only to check the sparse result

i_ind = torch.tensor([[12, 14, 51], [11, 6, 32], [1, 5, 17]]).long()  # shape (ann_num, 3)
j_ind = torch.tensor([[21, 60, 33], [13, 16, 12], [7, 55, 29]]).long()  # same shape as i_ind
print('S_mat')
print(S_mat)

result = S_mat[i_ind, j_ind]

print('method 1 result:')
print(result)


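# Only compute distances for the selected pairs.
input_sample_i = a[i_ind, :]  # shape (ann_num, 3, c)
input_sample_j = a[j_ind, :]

# pairwise_distance works row by row on (N, c) inputs: flatten, then restore (ann_num, 3).
# Note: it adds eps=1e-6 to the difference by default, so values can differ slightly from cdist.
S_mat_sample2 = F.pairwise_distance(input_sample_i.view(-1, c), input_sample_j.view(-1, c)).view(ann_num, -1)
print('S_mat_sample2')
print(S_mat_sample2)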

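I’ve figured out this problem: index the selected rows of `a` first, then compute the distance only for those pairs with `F.pairwise_distance`, as in the second half of the code above.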
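The same approach can be wrapped in a small reusable function. This is a sketch: `selected_cdist` is a hypothetical helper, not part of PyTorch. It accepts index tensors of any (matching) shape, such as the (ann_num, 3) batches above, returns distances with that same shape, and supports other p-norms through `torch.linalg.vector_norm`.

```python
import torch


def selected_cdist(a, i_ind, j_ind, p=2.0):
    """Distances between a[i_ind[k]] and a[j_ind[k]] only.

    Hypothetical helper. i_ind and j_ind must have the same shape;
    the result has that shape too.
    """
    if i_ind.shape != j_ind.shape:
        raise ValueError("i_ind and j_ind must have the same shape")
    # a[i_ind] has shape i_ind.shape + (c,); reduce over the feature dim.
    return torch.linalg.vector_norm(a[i_ind] - a[j_ind], ord=p, dim=-1)


torch.manual_seed(0)
a = torch.randn(8 * 8, 2)
i_ind = torch.tensor([[12, 14, 51], [11, 6, 32], [1, 5, 17]])
j_ind = torch.tensor([[21, 60, 33], [13, 16, 12], [7, 55, 29]])

result = selected_cdist(a, i_ind, j_ind)  # shape (3, 3)

# Full matrix only for checking on this small example; the direct (non-matmul)
# mode avoids the extra rounding error of the matrix-multiplication shortcut.
full = torch.cdist(a, a, compute_mode='donot_use_mm_for_euclid_dist')
print(torch.allclose(result, full[i_ind, j_ind], atol=1e-5))
```

Unlike `F.pairwise_distance`, this does not add an `eps` to the difference, so it agrees with the direct `cdist` computation up to floating-point rounding.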
