I want to select a subset of a 4-D tensor A (the attention matrix of a multi-head attention block) based on the indices specified by a 3-D tensor B. A has shape (a, b, c, d) and B has shape (a, b, g); B contains the indices I want to select from A along axis c. Here g < c, meaning that I want to select a subset of size g from axis c of A. For every selected index, the full vector along dimension d of A should be kept, i.e. C[i, j, k, :] = A[i, j, B[i, j, k], :]. The resulting subset C should therefore have shape (a, b, g, d).
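To pin down the intended semantics, here is a minimal, deliberately slow loop version. The name select_subset_loop is hypothetical, and it assumes B holds int64 indices with every entry in [0, c). It is mainly useful as a reference to check a vectorized version against.

```python
import torch

def select_subset_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference version: C[i, j, k, :] = A[i, j, B[i, j, k], :]."""
    a, b, c, d = A.shape
    g = B.shape[2]
    C = A.new_empty((a, b, g, d))
    for i in range(a):
        for j in range(b):
            for k in range(g):
                # copy the whole length-d vector found at position B[i, j, k] on axis c
                C[i, j, k] = A[i, j, B[i, j, k]]
    return C

# small example: a=1, b=2, c=3, d=2, g=2
A = torch.arange(12, dtype=torch.float32).reshape(1, 2, 3, 2)
B = torch.tensor([[[2, 0], [1, 1]]])
C = select_subset_loop(A, B)
print(C.shape)  # torch.Size([1, 2, 2, 2])
```

Note that nothing here forbids repeated indices in B (the second row selects index 1 twice).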

If you know how to do this, I would really appreciate your help. I’ve tried to solve this on my own for hours, but I could not figure it out. Thank you for your reply!

Since I messed up the original description, the suggested answer was not quite the solution I was looking for. Luckily, I have found a solution in the meantime. I have also updated the post, and here is the solution for anyone looking for it in the future:

import torch

def select_subset(A, B):
    # A.shape = (a, b, c, d) (float)
    # B.shape = (a, b, g) (int64 indices, each 0 <= B[i, j, k] < c)
    # C.shape = (a, b, g, d) (in parallel for every (a, b), select the length-d vectors of A at the g indices along axis c)

    a, b, c, d = A.shape
    a_, b_, g = B.shape
    assert a == a_ and b == b_ and g <= c, "Invalid shapes of A and B!"
    # torch.gather needs an index with the same number of dims as A, so repeat each index along d
    indices = B.unsqueeze(-1).expand(a, b, g, d)  # shape: (a, b, g, d)
    C = torch.gather(A, dim=2, index=indices)  # shape: (a, b, g, d)
    return C
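For comparison, the same selection can also be sketched with advanced indexing instead of torch.gather (a swapped-in alternative, not the approach from the post). The function name select_subset_indexing is hypothetical, and it assumes B is an int64 tensor with entries in [0, c). The example checks that it matches the gather-based result on random data.

```python
import torch

def select_subset_indexing(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical alternative: C[i, j, k, :] = A[i, j, B[i, j, k], :] via advanced indexing."""
    a, b, c, d = A.shape
    # index tensors of shape (a, 1, 1) and (1, b, 1) broadcast against B of shape (a, b, g)
    i = torch.arange(a, device=A.device).view(a, 1, 1)
    j = torch.arange(b, device=A.device).view(1, b, 1)
    # the three index tensors broadcast to (a, b, g); the untouched last axis d is appended
    return A[i, j, B]

A = torch.randn(2, 4, 10, 8)              # (a, b, c, d)
B = torch.randint(0, 10, (2, 4, 3))       # (a, b, g), int64 by default
C_idx = select_subset_indexing(A, B)
C_gather = torch.gather(A, 2, B.unsqueeze(-1).expand(-1, -1, -1, A.shape[-1]))
print(C_idx.shape, torch.equal(C_idx, C_gather))  # torch.Size([2, 4, 3, 8]) True
```

The indexing version avoids building the expanded (a, b, g, d) index explicitly, while the gather version states the selected axis directly through dim=2; both return a new tensor of shape (a, b, g, d).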