How can I avoid the for loops in a 3D tensor problem?

Hi guys!

I have a tensor called classification with size (B,1,2048), where B is the batch size.

From it I extract the indices of the top k1 and top k2 elements for each batch element, so I have two more tensors of size (B,1,k1) and (B,1,k2).

In this case k1 and k2 are both 1024, but in general they can differ.

Then I have another tensor called grasp_labels of shape (B,N,2), holding N pairs of indices per batch element. I have to build a new tensor combined_labels with shape (B,k1,k2) that contains 1 at position (b,i,j) if the pair (topk1[b,0,i], topk2[b,0,j]) is one of the N pairs in grasp_labels[b], and 0 otherwise.

The problem is that I can only conceive of this using for loops, which is not a good idea because I lose the gradient flow. I need a vectorized PyTorch way to do this.
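On the gradient concern: torch.topk returns the selected values (which are differentiable) and integer indices (which are not), so a 0/1 tensor built only from topk indices and integer pair labels has no path back to classification, whether it is built with loops or vectorized. A quick check, using the sizes from the question with a batch size of 4 chosen arbitrarily:

```python
import torch

classification = torch.randn(4, 1, 2048, requires_grad=True)
top = torch.topk(classification, 1024)

# The selected values keep the autograd graph, the indices are plain int64 tensors.
print(top.values.requires_grad)                       # True
print(top.indices.dtype, top.indices.requires_grad)   # torch.int64 False
```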

Pseudocode with for loops:

combined_labels = torch.zeros(B, k1, k2)
for b in range(0, B):
    for i in range(0, k1):
        for j in range(0, k2):
            a = topk1[b, 0, i]
            c = topk2[b, 0, j]
            if [a, c] in grasps_labels[b, :, :]:
                combined_labels[b, i, j] = 1
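For small inputs, the loop can be turned into a runnable reference to check any vectorized version against. In this sketch the `in` test is replaced by comparing the pair with every row of grasps_labels[b]; the function name combined_labels_loop and the toy tensors are made up for illustration.

```python
import torch

def combined_labels_loop(topk1, topk2, grasps_labels):
    """Slow reference: topk1 (B,1,k1), topk2 (B,1,k2), grasps_labels (B,N,2) -> (B,k1,k2)."""
    B, _, k1 = topk1.shape
    k2 = topk2.shape[2]
    combined_labels = torch.zeros(B, k1, k2)
    for b in range(B):
        for i in range(k1):
            for j in range(k2):
                pair = torch.stack((topk1[b, 0, i], topk2[b, 0, j]))  # shape (2,)
                # row-wise match: does any (a, c) row of grasps_labels[b] equal `pair`?
                if (grasps_labels[b] == pair).all(dim=1).any():
                    combined_labels[b, i, j] = 1
    return combined_labels

# hypothetical toy input: B=1, k1=k2=2, N=3
topk1 = torch.tensor([[[4, 1]]])
topk2 = torch.tensor([[[2, 5]]])
grasps_labels = torch.tensor([[[4, 5], [1, 2], [0, 3]]])
print(combined_labels_loop(topk1, topk2, grasps_labels))
```

Here the candidate pairs are (4,2), (4,5), (1,2) and (1,5); only (4,5) and (1,2) are labelled, so the result is [[0, 1], [1, 0]] for the single batch element.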

Any ideas? Thinking in more than two dimensions sometimes blows my mind.

Thanks a lot in advance!

I’m a bit confused as to why gradients need to flow, as it looks like this is computing labels. Here’s a sketch of an alternative implementation that doesn’t use Python for loops (although iteration is still heavily used). The main change is that [a,c] in grasps_labels[b,:,:] (which doesn’t seem to be valid PyTorch) is replaced by a lookup into a precomputed grasps_labels_matrix.

Note that this is untested as I don’t have any reference input/output to compare with:

import torch

B = 32     # batch size
L = 2048   # number of points scored in classification
k1 = 1024
k2 = 1024
N = 2048   # number of labelled (a, c) pairs per batch element

classification = torch.randn(B, 1, L)
topk1 = torch.topk(classification, k1).indices  # (B, 1, k1)
topk2 = torch.topk(classification, k2).indices  # (B, 1, k2)

# create a mapping from [b,a,c] to 0/1; a and c are point indices, so the table is (B, L, L)
grasps_labels = torch.randint(0, L, (B, N, 2)).reshape(B*N, 2)
grasps_indices = torch.unbind(torch.cat((torch.arange(B).repeat_interleave(N).unsqueeze(1), grasps_labels), dim=1), dim=1)
grasps_labels_matrix = torch.zeros(B, L, L)
grasps_labels_matrix.index_put_(grasps_indices, torch.tensor(1.0))

# collect indices to check: flat position i*k2 + j must hold (topk1[b,0,i], topk2[b,0,j])
indices_k1_repeat = torch.repeat_interleave(topk1, k2, dim=2).reshape(B*k1*k2)  # each topk1 entry repeated k2 times
indices_k2_repeat = topk2.repeat(1, 1, k1).reshape(B*k1*k2)  # whole topk2 row tiled k1 times
indices_k1k2 = torch.stack((indices_k1_repeat, indices_k2_repeat), dim=1)
indices_batch = torch.arange(B).repeat_interleave(k1*k2).unsqueeze(1)
indices_combined = torch.unbind(torch.cat((indices_batch, indices_k1k2), dim=1), dim=1)
print(len(indices_combined[0]), B*k1*k2)  # both should be B * k1 * k2
combined_labels = grasps_labels_matrix[indices_combined].reshape(B, k1, k2)
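The repeat/stack/unbind steps can also be expressed with broadcasting: indexing the (B, L, L) lookup table with index tensors of shapes (B,1,1), (B,k1,1) and (B,1,k2) broadcasts them to (B,k1,k2) and returns combined_labels directly. A second variant skips the dense table by encoding each (b, a, c) as the single integer b*L*L + a*L + c and testing membership with torch.isin (PyTorch 1.10 or later). This is a sketch with small made-up sizes, assuming all point indices lie in [0, L); both variants should agree.

```python
import torch

B, L, N, k1, k2 = 2, 16, 10, 4, 3   # small illustrative sizes (assumptions)

classification = torch.randn(B, 1, L)
topk1 = torch.topk(classification, k1).indices   # (B, 1, k1)
topk2 = torch.topk(classification, k2).indices   # (B, 1, k2)
grasps_labels = torch.randint(0, L, (B, N, 2))   # (B, N, 2) labelled pairs (a, c)

# Variant 1: dense 0/1 table, table[b, a, c] = 1 if (a, c) is labelled for batch element b
table = torch.zeros(B, L, L)
table[torch.arange(B)[:, None], grasps_labels[..., 0], grasps_labels[..., 1]] = 1.0
combined_labels = table[torch.arange(B)[:, None, None],   # (B, 1, 1)
                        topk1[:, 0, :, None],             # (B, k1, 1)
                        topk2[:, 0, None, :]]             # (B, 1, k2) -> result (B, k1, k2)

# Variant 2: no L x L table; one integer key per (b, a, c), batch offset keeps batches disjoint
offset = torch.arange(B) * L * L
label_keys = (offset[:, None] + grasps_labels[..., 0] * L + grasps_labels[..., 1]).reshape(-1)
cand_keys = offset[:, None, None] + topk1[:, 0, :, None] * L + topk2[:, 0, None, :]
combined_labels_isin = torch.isin(cand_keys, label_keys).float()

print(torch.equal(combined_labels, combined_labels_isin))  # True
```

The dense table holds B*L*L entries, about 512 MiB in float32 for B=32 and L=2048; creating it with dtype=torch.bool (and assigning True) cuts that to a quarter, and the isin variant avoids it altogether at the cost of a (B, k1, k2) int64 key tensor.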