Need something like gather/select for a multidimensional case

I tried using gather and select for this but couldn’t get either of them to work. For simplicity, here is my code first:

# hidden is a [B, N, D] transformer hidden state
# eot_inds is a [B] int tensor of indices
# y is meant to be [B, D]  (embeddings of certain positions specified by eot_inds in each sequence from transformer hidden state)
B, N, D = hidden.shape
y = torch.zeros(B, D)
for i in range(B):
    y[i] = hidden[i, eot_inds[i]]

How would I “pytorchify” this so I don’t need a for loop?

If I understand your use case correctly, this should work:

import torch

B, N, D = 2, 3, 4
hidden = torch.randn(B, N, D)
y = torch.zeros(B, D)
eot_inds = torch.randint(0, N, (B,))

# Reference result from the original loop
for i in range(B):
    y[i] = hidden[i, eot_inds[i]]

# Vectorized: pair each batch index i with its position eot_inds[i]
out = hidden[torch.arange(hidden.size(0)), eot_inds]
print((out == y).all())
> tensor(True)
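
Since the question mentions gather: the indexing expression above works because the two index tensors both have shape [B], so element i of the result is hidden[i, eot_inds[i]], which is a [D] vector. The same selection can also be written with gather. The following is only a sketch of one way to do it, reusing the shapes from the example above; the names out_index, out_gather and idx are made up for illustration. gather needs an index with the same number of dimensions as the input, so eot_inds is reshaped to [B, 1, 1] and expanded to [B, 1, D] first.

```python
import torch

B, N, D = 2, 3, 4
hidden = torch.randn(B, N, D)
eot_inds = torch.randint(0, N, (B,))

# Advanced indexing, as in the answer above: result has shape [B, D]
out_index = hidden[torch.arange(B), eot_inds]

# gather along dim 1: out[i, 0, k] = hidden[i, idx[i, 0, k], k]
# idx repeats each position index across the D feature columns
idx = eot_inds.view(B, 1, 1).expand(B, 1, D)
out_gather = hidden.gather(1, idx).squeeze(1)  # [B, 1, D] -> [B, D]

print(torch.equal(out_index, out_gather))
```

The indexing version is shorter and is probably the easier one to read. The gather version is mainly useful to see how the index has to be shaped for gather to work here.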