Indexing a tensor with a tensor without a for loop

I have a tensor A with shape (32, 200, 50, 2) and a tensor B with shape (32, 100, 100, 128).
A is an integer tensor whose values lie between 0 and 99 (inclusive), so it can be used to index dimensions 1 and 2 of B.
From A and B, I want to build a tensor C with shape (32, 200, 50, 128).
The loop I currently use to compute C is:

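import torch

C = torch.zeros(*A.shape[:-1], B.shape[-1], dtype=B.dtype)
for b in range(A.shape[0]):
    for i in range(A.shape[1]):
        for j in range(A.shape[2]):
            # Feature vector at row A[b,i,j,0], column A[b,i,j,1] of batch b
            C[b, i, j, :] = B[b, A[b, i, j, 0], A[b, i, j, 1], :]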

How can I do the same thing without the nested for loops?

Direct (advanced) indexing should work:

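import torch

A = torch.randint(0, 100, (32, 200, 50, 2))
B = torch.randn(32, 100, 100, 128)

# Reference result from the loop in the question
C = torch.zeros(*A.shape[:-1], B.shape[-1], dtype=B.dtype)
for b in range(A.shape[0]):
    for i in range(A.shape[1]):
        for j in range(A.shape[2]):
            C[b, i, j, :] = B[b, A[b, i, j, 0], A[b, i, j, 1], :]

# Vectorized: the (32, 1, 1) batch index broadcasts against the two
# (32, 200, 50) index tensors, so out has shape (32, 200, 50, 128)
out = B[torch.arange(B.size(0))[:, None, None], A[:, :, :, 0], A[:, :, :, 1]]
print((out == C).all())
# tensor(True)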
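To see why the one-liner produces the right shape, here is a small sketch with hypothetical toy sizes (the names `batch`, `n_i`, `n_j`, `h`, `w`, `feat` are illustrative, not from the thread). Using different heights and widths makes it clear that `A[..., 0]` indexes dimension 1 of B and `A[..., 1]` indexes dimension 2. The three index tensors broadcast to a common shape, and the remaining (unindexed) last dimension of B is appended to it.

```python
import torch

# Hypothetical toy sizes so the loop reference runs instantly.
batch, n_i, n_j, h, w, feat = 2, 3, 4, 5, 6, 7
A = torch.stack(
    [torch.randint(0, h, (batch, n_i, n_j)),   # row indices into dim 1 of B
     torch.randint(0, w, (batch, n_i, n_j))],  # column indices into dim 2 of B
    dim=-1,
)  # shape (batch, n_i, n_j, 2)
B = torch.randn(batch, h, w, feat)

# Batch index of shape (batch, 1, 1); it broadcasts against (batch, n_i, n_j).
batch_idx = torch.arange(B.size(0))[:, None, None]

# All three index tensors broadcast to (batch, n_i, n_j); the last dim of B
# is not indexed, so it is kept and appended: (batch, n_i, n_j, feat).
out = B[batch_idx, A[..., 0], A[..., 1]]
print(out.shape)  # torch.Size([2, 3, 4, 7])

# Loop reference, same logic as the question's code.
ref = torch.zeros(*A.shape[:-1], B.shape[-1], dtype=B.dtype)
for b in range(A.shape[0]):
    for i in range(A.shape[1]):
        for j in range(A.shape[2]):
            ref[b, i, j] = B[b, A[b, i, j, 0], A[b, i, j, 1]]
print(torch.equal(out, ref))  # True
```

The values are copied, not computed, so an exact `torch.equal` comparison is appropriate here.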
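An alternative technique, not the one used in the answer above, is to collapse the two spatial dimensions of B into one and use `torch.gather` with a row-major linear index `row * w + col`. This is only a sketch: `gather_rows_cols` is a hypothetical helper name, it assumes B is laid out as (batch, h, w, feat) and that every index in A is in range, and the sizes below are shrunk so the check runs quickly.

```python
import torch

def gather_rows_cols(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: C[b, i, j] = B[b, A[b, i, j, 0], A[b, i, j, 1]]."""
    batch, h, w, feat = B.shape
    # Collapse the two spatial dims of B: (batch, h * w, feat).
    B_flat = B.reshape(batch, h * w, feat)
    # Row-major linear index into the collapsed dim: row * w + col.
    lin = A[..., 0] * w + A[..., 1]                          # (batch, n_i, n_j)
    # gather needs an index with as many dims as B_flat, so repeat it
    # (as a view) across the feature dim: (batch, n_i * n_j, feat).
    lin = lin.reshape(batch, -1, 1).expand(-1, -1, feat)
    # out[b, k, f] = B_flat[b, lin[b, k, f], f]
    out = torch.gather(B_flat, 1, lin)
    return out.reshape(*A.shape[:-1], feat)

# Smaller than the thread's shapes, purely to keep the check fast.
A = torch.randint(0, 10, (4, 20, 10, 2))
B = torch.randn(4, 10, 10, 16)
C1 = B[torch.arange(B.size(0))[:, None, None], A[..., 0], A[..., 1]]
C2 = gather_rows_cols(A, B)
print(C2.shape)            # torch.Size([4, 20, 10, 16])
print(torch.equal(C1, C2))  # True
```

Both approaches produce the same tensor; the advanced-indexing one-liner is shorter, while the gather form makes the flattening explicit, which can be handy if the index is already stored as a linear offset.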