Converting topk indices into a pairwise list

Hi,
I am looking for an efficient way to convert the indices returned by topk() into a pairwise array without using a for loop, or at least the most runtime-efficient way possible.

For example,

import torch

x_ = torch.randn(10, 1)
key_dist_s = torch.cdist(x_, x_, p=2)  # (10, 10) pairwise distances

# Indices of the 3 largest distances in each row, shape (10, 3)
idx = torch.topk(key_dist_s, k=3, largest=True)
idx_indices = idx.indices

I want to reshape an idx_indices tensor such as

tensor([[7, 5, 9],
        [4, 0, 8],
        [7, 4, 5],
        [4, 1, 8]])

into the following, where the first index of each row is paired with each of the other indices in that row, without using a for loop:

tensor([[7, 5],
        [7, 9],
        [4, 0],
        [4, 8],
        [7, 4],
        [7, 5],
        [4, 1],
        ...])

Can anyone suggest a solution?
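For reference, here is the plain Python-loop version that the question wants to avoid. It is only a sketch that pins down the expected output on the 4-row example tensor; the function name `pairs_with_loop` is hypothetical, and it assumes column 0 of each row is the "source" that gets paired with every remaining column.

```python
import torch

def pairs_with_loop(idx_indices: torch.Tensor) -> torch.Tensor:
    """Reference (slow) version: pair column 0 of each row with every other column."""
    pairs = []
    for row in idx_indices.tolist():
        source, *targets = row
        for target in targets:
            pairs.append([source, target])
    # reshape keeps a (0, 2) shape when there are no targets at all
    return torch.tensor(pairs, dtype=idx_indices.dtype).reshape(-1, 2)

# The 4-row example from the question
idx_indices = torch.tensor([[7, 5, 9],
                            [4, 0, 8],
                            [7, 4, 5],
                            [4, 1, 8]])
out = pairs_with_loop(idx_indices)
print(out)  # rows: [7, 5], [7, 9], [4, 0], [4, 8], [7, 4], [7, 5], [4, 1], [4, 8]
```

Any vectorized version should produce exactly the same tensor, which makes this loop handy as a test oracle.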

Do you mean a solution with no Python for loops? Any solution will still loop somewhere underneath, so hopefully the native loops inside PyTorch's kernels are fast enough here :wink:

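import torch

x_ = torch.randn(10, 1)
key_dist_s = torch.cdist(x_, x_, p=2)

idx = torch.topk(key_dist_s, k=3, largest=True)
idx_indices = idx.indices          # shape (10, 3)
a = idx_indices[:, 0:2]            # pairs (col 0, col 1)
b = idx_indices[:, [0, 2]]         # pairs (col 0, col 2)
print(a)
print(b)
# Each row becomes [a0, a1, b0, b1]; reshaping splits it into two pairs
temp = torch.cat((a, b), dim=1)    # shape (10, 4)
total_len = temp.size(0) * 2
out = temp.reshape(total_len, -1)  # shape (20, 2)
print(out)

Output (a, b, then out):
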
tensor([[3, 4],
        [5, 3],
        [5, 6],
        [5, 6],
        [5, 6],
        [3, 4],
        [3, 4],
        [5, 3],
        [3, 5],
        [5, 3]])
tensor([[3, 2],
        [5, 4],
        [5, 0],
        [5, 0],
        [5, 0],
        [3, 2],
        [3, 2],
        [5, 4],
        [3, 4],
        [5, 4]])
tensor([[3, 4],
        [3, 2],
        [5, 3],
        [5, 4],
        [5, 6],
        [5, 0],
        [5, 6],
        [5, 0],
        [5, 6],
        [5, 0],
        [3, 4],
        [3, 2],
        [3, 4],
        [3, 2],
        [5, 3],
        [5, 4],
        [3, 5],
        [3, 4],
        [5, 3],
        [5, 4]])
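The slicing above hard-codes k=3 (columns [0, 1] and [0, 2]). A minimal sketch of the same idea for any k, again assuming column 0 is the source: broadcast the first column against the remaining k-1 columns with `expand_as`, stack the two along a new last dimension, and flatten. `topk_to_pairs` is a hypothetical helper name, not a PyTorch function.

```python
import torch

def topk_to_pairs(idx_indices: torch.Tensor) -> torch.Tensor:
    """Turn an (N, k) topk index tensor into an (N*(k-1), 2) tensor of [source, target] pairs."""
    targets = idx_indices[:, 1:]                     # (N, k-1)
    sources = idx_indices[:, :1].expand_as(targets)  # (N, k-1), a broadcast view, no copy
    # Stack into (N, k-1, 2), then flatten the first two dims in row-major order
    return torch.stack((sources, targets), dim=-1).reshape(-1, 2)

x_ = torch.randn(10, 1)
key_dist_s = torch.cdist(x_, x_, p=2)
idx_indices = torch.topk(key_dist_s, k=3, largest=True).indices
out = topk_to_pairs(idx_indices)
print(out.shape)  # torch.Size([20, 2])
```

For k=3 this yields the same row order as the `cat` + `reshape` trick, but it no longer needs one slice per target column.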

Hi eqy, thanks!

In the end, I did something like this:

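import torch

x_ = torch.randn(10, 1)
key_dist_s = torch.cdist(x_, x_, p=2)

topk = 3
idx = torch.topk(key_dist_s, k=topk, largest=True)
idx_indices = idx.indices                                # (N, topk)

# Column 0 is the "source"; the remaining topk-1 columns are "targets"
sources_sig = idx_indices[:, 0].unsqueeze(-1).float()    # (N, 1)
targets_sig = idx_indices[:, 1:].float()                 # (N, topk-1)

# Repeat each source topk-1 times via an outer product with a row of ones
ones_like = torch.ones(1, topk - 1, device=idx_indices.device).float()
sources_sig = torch.matmul(sources_sig, ones_like)       # (N, topk-1)

# Flatten row-major so each source lines up with its targets
sources_sig_long = sources_sig.long().view(-1)
targets_sig_long = targets_sig.long().view(-1)

pairwise_sig_st = torch.stack([sources_sig_long, targets_sig_long], dim=1)  # (N*(topk-1), 2)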

Anyway, feel free to suggest better alternatives :slight_smile:
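One possible alternative to the matmul approach, sketched under the same assumption (column 0 is the source, columns 1: are the targets): `repeat_interleave` repeats each source k-1 times directly on the integer tensor, so the `float()` / `matmul` / `long()` round trip is unnecessary. Avoiding that round trip also matters for correctness, because float32 represents integers exactly only up to 2**24, so very large indices could be silently corrupted. The function name `topk_to_pairs_interleave` is hypothetical.

```python
import torch

def topk_to_pairs_interleave(idx_indices: torch.Tensor) -> torch.Tensor:
    """Pair column 0 of each row with each of the other k-1 columns, staying in integer dtype."""
    k = idx_indices.size(1)
    # [s0, s0, ..., s1, s1, ...]: each source repeated k-1 times, in row order
    sources = idx_indices[:, 0].repeat_interleave(k - 1)
    # [t0_1, t0_2, ..., t1_1, ...]: targets flattened row-major (reshape copies the slice if needed)
    targets = idx_indices[:, 1:].reshape(-1)
    return torch.stack([sources, targets], dim=1)

x_ = torch.randn(10, 1)
key_dist_s = torch.cdist(x_, x_, p=2)
topk = 3
idx_indices = torch.topk(key_dist_s, k=topk, largest=True).indices
pairwise_sig_st = topk_to_pairs_interleave(idx_indices)
print(pairwise_sig_st.shape)  # torch.Size([20, 2])
```

The result keeps the dtype and device of `idx_indices`, so no explicit `device=` argument is needed.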