Now I have a tensor farthest_idxs of size (Batch, Feature) = (24, 32). I also have a tensor nearest_idxs of size (Points, Batch, Feature) = (1024, 24, 1023). For one point p and one sample s (i.e. nearest_idxs[p, s, :], a vector of length 1023), I want to find the first element of this vector that is also in farthest_idxs[s, :] (length 32), and record the results in a matrix of size 24x1024. Is there an efficient way to implement this?
Here is my current implementation, which is inefficient:
    import torch

    def nearest_indices(self, relation, farthest_idxs):
        '''Generate the nearest indices.
        return: [B, N] matrix
        '''
        device = relation.device
        # Indices of the 1023 smallest relation values per row, in ascending order.
        nearest_value, nearest_idxs = torch.topk(relation, k=1023, dim=2, largest=False, sorted=True)
        nearest_idxs = nearest_idxs.transpose(0, 1)  # -> 1024x24x1023 (N, B, P)
        N, B, P = nearest_idxs.shape
        upsample_idxs = torch.zeros((B, N), dtype=torch.long, device=device)
        for n in range(N):
            for b in range(B):
                for p in range(P):
                    # Python-level membership test, run once per candidate: this is the slow part.
                    if nearest_idxs[n, b, p] in farthest_idxs[b, :]:
                        upsample_idxs[b, n] = nearest_idxs[n, b, p]
                        break
        return upsample_idxs  # [B, N] = 24x1024
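One possible way to remove the three Python loops is sketched below. It is an illustration under stated assumptions, not a tested drop-in replacement: it assumes nearest_idxs has already been transposed to (N, B, P), that all indices are non-negative point ids, and that 0 is an acceptable "no match" value (the same default the loop leaves in upsample_idxs). The function name first_match is hypothetical. The idea is to build a per-batch boolean lookup table is_farthest[b, i] so that the membership test becomes a single indexing operation, and then find the first True position along the last dimension.

```python
import torch


def first_match(nearest_idxs: torch.Tensor, farthest_idxs: torch.Tensor) -> torch.Tensor:
    """Hypothetical vectorized version of the triple loop.

    nearest_idxs:  [N, B, P] long tensor (e.g. 1024 x 24 x 1023), already transposed
    farthest_idxs: [B, F]    long tensor (e.g. 24 x 32)
    returns:       [B, N]    long tensor; entry (b, n) is the first element of
                   nearest_idxs[n, b, :] that appears in farthest_idxs[b, :], or 0 if none does.
    """
    N, B, P = nearest_idxs.shape
    device = nearest_idxs.device

    # Assumption: indices are non-negative point ids, so one table column per id suffices.
    num_points = int(torch.maximum(nearest_idxs.max(), farthest_idxs.max())) + 1

    # is_farthest[b, i] is True iff point id i occurs in farthest_idxs[b, :].
    is_farthest = torch.zeros(B, num_points, dtype=torch.bool, device=device)
    batch = torch.arange(B, device=device)
    is_farthest[batch[:, None], farthest_idxs] = True

    # Membership mask for every candidate, via broadcast advanced indexing: [N, B, P].
    mask = is_farthest[batch.view(1, B, 1), nearest_idxs]

    # Position of the first True along P; rows without any match get the sentinel P.
    positions = torch.arange(P, device=device).view(1, 1, P)
    sentinel = torch.tensor(P, device=device)
    first_pos = torch.where(mask, positions, sentinel).min(dim=2).values  # [N, B]

    # Read the index stored at that position (clamped so gather stays in range).
    found = first_pos < P
    picked = nearest_idxs.gather(2, first_pos.clamp(max=P - 1).unsqueeze(-1)).squeeze(-1)
    picked = torch.where(found, picked, torch.zeros_like(picked))  # 0 when nothing matched
    return picked.transpose(0, 1)  # [N, B] -> [B, N]
```

Inside the method, the loops could then be replaced by `upsample_idxs = first_match(nearest_idxs, farthest_idxs)`. The intermediate mask and position tensors have N*B*P elements (about 25 million for 1024x24x1023), which is the main memory cost of this sketch. A possible further shortcut, not shown here: because topk returns the indices sorted by ascending relation, the first match should be the farthest point with the smallest relation value, so gathering relation at the 32 farthest columns and taking an argmin may give the same result with far less memory. That equivalence can break on ties or when the only match is the one column excluded by k=1023, so it is worth checking against the loop first.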