How to select a subset of a 4-D tensor based on 3-D indices

Hello PyTorch community!

I want to select a subset of a 4-D tensor (the attention matrix of a multi-head attention block) based on indices given by a 3-D tensor B. Tensor A has shape (a, b, c, d) and tensor B has shape (a, b, g). B contains the indices I want to select from A. Here g < c, meaning that I want to select a subset of size g along axis c of A. The corresponding entries along dimension d of A should be kept, so the resulting subset of A should have shape (a, b, g, d).

If you know how to do this, I would really appreciate your help. I’ve tried to solve this on my own for hours but I could not figure it out. Thank you for your reply!

You haven’t posted any (slow) reference code, so I cannot verify my approach, but this might work:

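import torch

a, b, c, d = 2, 3, 4, 5
g = 3  # g < c, as in the question

A = torch.randn(a, b, c, d)
B = torch.randint(0, c, (a, b, g))  # indices into axis c
# The arange tensors for the first two axes broadcast against B, shape (a, b, g)
out = A[torch.arange(a)[:, None, None], torch.arange(b)[:, None], B]
print(out.shape)
# torch.Size([2, 3, 3, 5]) = [a, b, g, d]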
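
For anyone who wants to check the indexing expression, a slow loop-based reference could look like the sketch below. It assumes the intended semantics are out[i, j, k, :] = A[i, j, B[i, j, k], :]; the helper name select_rows_reference and the concrete sizes are made up for illustration.

```python
import torch


def select_rows_reference(A, B):
    # Hypothetical helper: plain Python loops, meant only for checking results.
    a, b, c, d = A.shape
    g = B.shape[-1]
    out = torch.empty(a, b, g, d, dtype=A.dtype)
    for i in range(a):
        for j in range(b):
            for k in range(g):
                # Copy the whole d-vector found at position B[i, j, k] of axis c.
                out[i, j, k] = A[i, j, int(B[i, j, k])]
    return out


torch.manual_seed(0)
A = torch.randn(2, 3, 4, 5)           # (a, b, c, d)
B = torch.randint(0, 4, (2, 3, 3))    # (a, b, g), entries in [0, c)

# Vectorized version using broadcasted index tensors.
fast = A[torch.arange(2)[:, None, None], torch.arange(3)[:, None], B]
slow = select_rows_reference(A, B)
print(fast.shape, torch.equal(fast, slow))
```

Since both versions only copy values, an exact comparison with torch.equal should be enough here; no tolerance is needed.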

Thank you for your reply!

Since I messed up the description, this was not quite the solution I was looking for. Luckily, I have found a solution in the meantime. I have also updated the post, and here is the solution for anyone who comes looking for it in the future:

import torch

def select_subset(A, B):
    # A.shape = (a, b, c, d) (float)
    # B.shape = (a, b, g) (int64 indices, 0 <= B < c)
    # C.shape = (a, b, g, d) (for every (a, b) pair in parallel, select the d-vectors of A at the indices along c given by B)

    a, b, c, d = A.shape
    a_, b_, g = B.shape
    assert a == a_ and b == b_ and g <= c, "Invalid shapes of A and B!"

    # Repeat each index along the last axis so it selects the full d-vector
    indices = B.unsqueeze(-1).expand(a, b, g, d)  # shape: (a, b, g, d)
    C = torch.gather(A, dim=2, index=indices)  # shape: (a, b, g, d)
    return C
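
A tiny worked example may make it easier to see what the function returns. This is only a sketch: the function body is condensed from the one above, and the tensor values are made up so that the selected rows can be read off by eye.

```python
import torch


def select_subset(A, B):
    # Condensed copy of the gather-based function from the post.
    a, b, c, d = A.shape
    g = B.shape[-1]
    indices = B.unsqueeze(-1).expand(a, b, g, d)  # view, no data copied
    return torch.gather(A, dim=2, index=indices)


# Made-up sizes: a = 1, b = 1, c = 3, d = 2
A = torch.arange(6.0).reshape(1, 1, 3, 2)  # rows along c: [0, 1], [2, 3], [4, 5]
B = torch.tensor([[[2, 0]]])               # pick row 2, then row 0 (g = 2)

C = select_subset(A, B)
print(C.shape)   # torch.Size([1, 1, 2, 2])
print(C[0, 0])   # rows [4., 5.] and [0., 1.]
```

Note that torch.gather expects the index tensor to be int64, so if B comes from somewhere else with a different integer dtype, calling B.long() first is probably needed.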