I tried using gather and select for this, but couldn't get them to work here. For simplicity, I'll put my current code here first:
# hidden is a [B, N, D] transformer hidden state
# eot_inds is a [B] int tensor of indices
# y should be [B, D]: for each sequence in the batch, the hidden vector
# at the position given by eot_inds
B, N, D = hidden.shape
y = hidden.new_zeros(B, D)  # match hidden's dtype and device
for i in range(B):
    y[i] = hidden[i, eot_inds[i]]
How would I “pytorchify” this so I didn’t need a for loop?
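For reference, here is a sketch of two equivalent vectorized forms, assuming `hidden` and `eot_inds` live on the same device: advanced (integer-array) indexing with a batch-index tensor, and a `torch.gather` along the sequence dimension with the indices expanded to match `D`.

```python
import torch

# Toy shapes matching the question: hidden is [B, N, D], eot_inds is [B].
B, N, D = 4, 7, 3
hidden = torch.randn(B, N, D)
eot_inds = torch.randint(0, N, (B,))

# Option 1: advanced indexing — pair each batch index with its eot index.
y = hidden[torch.arange(B, device=hidden.device), eot_inds]  # [B, D]

# Option 2: torch.gather — expand the indices to [B, 1, D] so one gather
# along dim=1 picks a single position per sequence, then drop that dim.
idx = eot_inds.view(B, 1, 1).expand(B, 1, D)
y_gather = hidden.gather(1, idx).squeeze(1)  # [B, D]

# Both match the original for loop.
y_loop = torch.stack([hidden[i, eot_inds[i]] for i in range(B)])
assert torch.equal(y, y_loop) and torch.equal(y_gather, y_loop)
```

The advanced-indexing form is the shortest; `gather` is handy when the same pattern appears inside a larger expression or when you need it to work under `torch.jit` tracing constraints.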