I have a tensor called classification of size (B, 1, 2048), where B is the batch size.
From it I extract the indices of the top k1 and top k2 elements for each batch element, which gives two more tensors of size (B, 1, k1) and (B, 1, k2).
In this case k1 and k2 are both 1024, but in general they can differ.
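For reference, here is a minimal sketch of how the index tensors described above could be obtained. It assumes they come from torch.topk along the last dimension. The random classification tensor and B = 4 are placeholders, not the real data:

```python
import torch

B, C = 4, 2048       # placeholder batch size; 2048 scores per batch element
k1, k2 = 1024, 1024  # may differ in general

classification = torch.randn(B, 1, C)  # placeholder scores

# torch.topk works along dim=-1 here; .indices is an int64 tensor of shape (B, 1, k)
topk1 = torch.topk(classification, k1, dim=-1).indices
topk2 = torch.topk(classification, k2, dim=-1).indices
```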
Then I have another tensor called grasp_labels of shape (B, N, 2), which holds N pairs of indices per batch element. I need to build a new tensor combined_labels of shape (B, k1, k2) such that combined_labels[b, i, j] is 1 if the pair (topk1[b, 0, i], topk2[b, 0, j]) is one of the N pairs in grasp_labels[b], and 0 otherwise.
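One vectorized approach (a sketch, not the only way) works in two steps. First, scatter the N ground-truth pairs of each batch element into a dense boolean lookup table of shape (B, 2048, 2048). Then read that table with broadcast advanced indexing using the top-k indices. The function name build_combined_labels and the num_classes parameter are hypothetical. The sketch assumes grasp_labels is an int64 tensor whose entries all lie in [0, num_classes), with no padding values such as -1, since negative indices would wrap around:

```python
import torch

def build_combined_labels(topk1, topk2, grasp_labels, num_classes):
    """Hypothetical helper: out[b, i, j] = 1 iff (topk1[b,0,i], topk2[b,0,j]) is a pair in grasp_labels[b]."""
    B, N, _ = grasp_labels.shape
    device = grasp_labels.device
    # Dense membership table: table[b, a, c] is True iff (a, c) appears in grasp_labels[b]
    table = torch.zeros(B, num_classes, num_classes, dtype=torch.bool, device=device)
    batch_idx = torch.arange(B, device=device).unsqueeze(1).expand(B, N)  # (B, N)
    table[batch_idx, grasp_labels[..., 0], grasp_labels[..., 1]] = True
    # Look up every (i, j) combination at once: the three index tensors broadcast to (B, k1, k2)
    b = torch.arange(B, device=device).view(B, 1, 1)  # (B, 1, 1)
    rows = topk1[:, 0, :].unsqueeze(2)                # (B, k1, 1)
    cols = topk2[:, 0, :].unsqueeze(1)                # (B, 1, k2)
    return table[b, rows, cols].float()               # (B, k1, k2)
```

The lookup table costs B × 2048 × 2048 bytes (4 MiB per batch element with torch.bool), which is usually acceptable. Because the row and column indices broadcast against each other, k1 and k2 do not need to be equal.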
The problem is that I can only conceive of doing this with for loops, which is not a good idea because I lose the gradient flow. I need a vectorized PyTorch way to do this.
Pseudocode with for loops:
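```python
import torch

combined_labels = torch.zeros(B, k1, k2)
for b in range(B):
    for i in range(k1):
        for j in range(k2):
            a = topk1[b, 0, i]  # dim 1 has size 1, so its only valid index is 0
            c = topk2[b, 0, j]
            # True if (a, c) matches some row of grasp_labels[b]
            if ((grasp_labels[b, :, 0] == a) & (grasp_labels[b, :, 1] == c)).any():
                combined_labels[b, i, j] = 1
```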
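combined_labels is built only from integer indices, and the indices returned by torch.topk carry no gradient. So the loop and a vectorized version are equally non-differentiable, and vectorizing mainly buys speed.

As an alternative that needs no table sized by the number of classes, you can encode every (b, a, c) triple as the single integer b·C² + a·C + c. You then test membership with torch.isin, which is available in PyTorch 1.10 and later. This is a sketch: the function name and the num_classes argument are hypothetical, and all indices are assumed to be valid int64 values in [0, num_classes):

```python
import torch

def build_combined_labels_isin(topk1, topk2, grasp_labels, num_classes):
    """Hypothetical helper: same result as the loop, via integer keys and torch.isin."""
    B = grasp_labels.shape[0]
    device = grasp_labels.device
    C = num_classes
    # One unique key per ground-truth pair and batch element: b*C*C + a*C + c
    b_lab = torch.arange(B, device=device).view(B, 1)
    label_keys = b_lab * C * C + grasp_labels[..., 0] * C + grasp_labels[..., 1]  # (B, N)
    # Keys for every candidate (i, j) combination, built by broadcasting
    b_q = torch.arange(B, device=device).view(B, 1, 1)
    a = topk1[:, 0, :].unsqueeze(2)  # (B, k1, 1)
    c = topk2[:, 0, :].unsqueeze(1)  # (B, 1, k2)
    query_keys = b_q * C * C + a * C + c  # (B, k1, k2)
    # Elementwise membership test against the flat set of label keys
    return torch.isin(query_keys, label_keys.flatten()).float()
```

Memory here scales with B × k1 × k2 int64 keys instead of B × C² booleans. That can be preferable when the number of classes is much larger than k1 and k2.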
Any ideas? Thinking in more than two dimensions sometimes blows my mind.
Thanks a lot in advance!