Efficient way to find the first matching index without nested loops

Hi there,

I have a tensor farthest_idxs of size (batch, feature) = (24, 32) and a tensor nearest_idxs of size (points, batch, feature) = (1024, 24, 1023). For one point p and one sample s, i.e. nearest_idxs[p,s,:] (a vector of length 1023), I want to find the first element of that vector that also appears in farthest_idxs[s,:] (a vector of length 32), and record the results in a matrix of size 24x1024. Is there an efficient way to implement this?

Here is my code, which is an inefficient way to implement it.

import torch

def nearest_indices(self, relation, farthest_idxs):
        '''Generate the nearest indices
            return:
                [B, N] matrix
        '''
        device = relation.device
        # Neighbours of every point, sorted from nearest to farthest
        nearest_value, nearest_idxs = torch.topk(relation, k=1023, dim=2, largest=False, sorted=True)
        nearest_idxs = nearest_idxs.transpose(0,1) # 1024x24x1023
        N, B, P = nearest_idxs.shape
        upsample_idxs = torch.zeros((B, N), dtype=torch.long, device=device)
        for n in range(N):
            for b in range(B):
                for p in range(P):
                    # Keep the first neighbour that is also in farthest_idxs[b]
                    if nearest_idxs[n, b, p] in farthest_idxs[b,:]:
                        upsample_idxs[b, n] = nearest_idxs[n, b, p]
                        break
        return upsample_idxs

Any help?
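
One possible way to drop the three loops, as a rough sketch rather than a definitive answer: build a boolean membership table from farthest_idxs, look every neighbour up in it, and take the smallest position where the lookup hits. The function name first_shared_neighbor and the parameter k are made up for illustration. The sketch assumes relation has shape [B, N, M], that farthest_idxs holds long indices into the last dimension of relation, and that entries with no match should stay 0 as in the loop version.

```python
import torch

def first_shared_neighbor(relation, farthest_idxs, k):
    """Hypothetical vectorized version of the triple loop; returns a [B, N] long tensor."""
    device = relation.device
    B, N, M = relation.shape  # assumed layout: [batch, points, candidates]

    # Neighbours of every point, nearest first: [B, N, k]
    _, nearest_idxs = torch.topk(relation, k=k, dim=2, largest=False, sorted=True)

    # is_far[b, j] is True when j appears in farthest_idxs[b]
    batch = torch.arange(B, device=device)
    is_far = torch.zeros(B, M, dtype=torch.bool, device=device)
    is_far[batch[:, None], farthest_idxs] = True

    # hit[b, n, p] is True when the p-th neighbour of point n is in farthest_idxs[b]
    hit = is_far[batch[:, None, None], nearest_idxs]

    # Smallest position p with a hit; k acts as a "not found" sentinel
    pos = torch.arange(k, device=device)
    sentinel = torch.tensor(k, device=device)
    first_pos = torch.where(hit, pos, sentinel).amin(dim=2)  # [B, N]
    found = first_pos < k

    # Read the neighbour index at that position; leave 0 where nothing matched
    picked = nearest_idxs.gather(2, first_pos.clamp(max=k - 1).unsqueeze(-1)).squeeze(-1)
    return torch.where(found, picked, torch.zeros_like(picked))
```

With the sizes from the question this would be called as first_shared_neighbor(relation, farthest_idxs, k=1023) and should give the same [24, 1024] result as the loops. The intermediate position tensor is [B, N, k] of type long, so it trades memory for speed.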