import torch
h, w = 8, 8
c = 2
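a = torch.randn(h*w, c)         # h*w points, each with c features
S_mat = torch.cdist(a, a)       # full (h*w, h*w) matrix of pairwise distances
i_ind = torch.tensor([12, 14, 51]).long()
j_ind = torch.tensor([21, 60, 33]).long()
result = S_mat[i_ind, j_ind]    # distances for the pairs (12, 21), (14, 60), (51, 33)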
In fact, the tensor `a` is very large (suppose h, w = 100, 100), so I do not want to calculate the distances between all pairs of points, only between several selected index pairs. Is there any way to do this?
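A minimal sketch of one way to do this, assuming `i_ind[k]` and `j_ind[k]` form the k-th pair as in the snippet above: gather only the rows that are needed and take the Euclidean norm of their difference. With h, w = 100, 100 there are 10,000 points, so the full `cdist` matrix would hold 10^8 float32 values (about 400 MB), whereas this needs memory only proportional to the number of selected pairs. The index values are simply the ones from the question.

```python
import torch

# Sizes from the question; a is random, as in the original post.
h, w = 100, 100
c = 2
a = torch.randn(h * w, c)

# Example index pairs (i_ind[k], j_ind[k]), taken from the question.
i_ind = torch.tensor([12, 14, 51])
j_ind = torch.tensor([21, 60, 33])

# Gather only the rows involved, then take the Euclidean norm of each difference.
# No (h*w, h*w) matrix is ever built.
diff = a[i_ind] - a[j_ind]                       # shape (3, c)
result = torch.linalg.vector_norm(diff, dim=-1)  # shape (3,)
print(result)
```

Because the norm is taken over the last dimension only, the same two lines work unchanged if `i_ind` and `j_ind` are 2-D index tensors.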
Is there a pattern to the way you want to perform this? Or is it completely random?
`a` is random, and `i_ind` and `j_ind` are random too.
import torch
import torch.nn.functional as F
ann_num = 3
h, w = 8, 8
c = 2
a = torch.randn(h*w, c)
S_mat = torch.cdist(a, a)  # full matrix, computed here only to check the result
i_ind = torch.tensor([[12, 14, 51], [11, 6, 32], [1, 5, 17]]).long()  # shape (ann_num, 3)
j_ind = torch.tensor([[21, 60, 33], [13, 16, 12], [7, 55, 29]]).long()
print('S_mat')
print(S_mat)
result = S_mat[i_ind, j_ind]
print('method 1 result:')
print(result)
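# Only compute distances for the selected set of pairs.
input_sample_i = a[i_ind, :]  # shape (ann_num, 3, c)
input_sample_j = a[j_ind, :]
# pairwise_distance expects (N, D) input, so flatten the pairs, then restore the (ann_num, 3) shape.
# Note: it adds eps=1e-6 to the difference, so values can differ slightly from cdist.
S_mat_sample2 = F.pairwise_distance(input_sample_i.view(-1, c), input_sample_j.view(-1, c)).view(i_ind.shape)
print('S_mat_sample2')
print(S_mat_sample2)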
I’ve figured out this problem.
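For anyone checking this approach, here is a sketch that compares it with the full-matrix method on the small 8x8 case. It assumes a fixed random seed (not in the original post) and uses `torch.linalg.vector_norm` as an alternative to `F.pairwise_distance`: advanced indexing keeps the shape of the index tensor, so `a[i_ind]` is already `(ann_num, 3, c)` and no flattening is needed. The reference uses `compute_mode='donot_use_mm_for_euclid_dist'` so that `cdist` computes distances directly rather than through the faster but slightly less precise matrix-multiplication path.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # reproducible comparison; not part of the original post
h, w, c = 8, 8, 2
a = torch.randn(h * w, c)

# Index pairs with shape (ann_num, 3), as in the solution above.
i_ind = torch.tensor([[12, 14, 51], [11, 6, 32], [1, 5, 17]])
j_ind = torch.tensor([[21, 60, 33], [13, 16, 12], [7, 55, 29]])

# a[i_ind] has shape (3, 3, c); the norm over the last dim gives a (3, 3) result.
selected = torch.linalg.vector_norm(a[i_ind] - a[j_ind], dim=-1)

# Reference: full distance matrix, affordable only because h*w is small here.
full = torch.cdist(a, a, compute_mode='donot_use_mm_for_euclid_dist')[i_ind, j_ind]

# F.pairwise_distance adds eps (default 1e-6) to the difference before the norm,
# so it matches only approximately; reshape its flat output back to (3, 3).
pd = F.pairwise_distance(a[i_ind].view(-1, c), a[j_ind].view(-1, c)).view(i_ind.shape)

print(torch.allclose(selected, full, atol=1e-6))
print(torch.allclose(pd, full, atol=1e-5))
```

Either variant avoids the quadratic memory of `cdist`; the direct norm simply skips the `eps` offset and the reshape.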