Is there a loopless native PyTorch equivalent of this op?
import torch

def index_dim(tensor, dim, index):
    # For each batch element k, select slice index[k] along `dim`.
    return torch.stack([tensor[k].select(dim - 1, i) for k, i in enumerate(index.tolist())])
t = torch.rand(2, 4, 3)
i = torch.LongTensor([0, 1])
print(index_dim(t, 1, i).shape)
# torch.Size([2, 3])
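For reference, here is one possible loopless version using `torch.gather` (a sketch of my own, not necessarily the suggestion made later in this thread): build an index tensor that picks one slice along `dim` per batch element, gather, then squeeze that dimension away.

```python
import torch

def index_dim(tensor, dim, index):
    # Loop-based original, kept for comparison.
    return torch.stack(
        [tensor[k].select(dim - 1, i) for k, i in enumerate(index.tolist())]
    )

def index_dim_gather(tensor, dim, index):
    # Loopless sketch: reshape `index` so it broadcasts over every
    # dimension except the batch dim, expand it to cover the slice
    # being selected, gather along `dim`, and drop the singleton dim.
    view_shape = [tensor.size(0)] + [1] * (tensor.dim() - 1)
    expand_shape = list(tensor.size())
    expand_shape[dim] = 1
    idx = index.view(view_shape).expand(expand_shape)
    return tensor.gather(dim, idx).squeeze(dim)

t = torch.rand(2, 4, 3)
i = torch.LongTensor([0, 1])
print(index_dim_gather(t, 1, i).shape)  # torch.Size([2, 3])
assert torch.equal(index_dim(t, 1, i), index_dim_gather(t, 1, i))
```

The same function works for any `dim >= 1`, since the expand shape is derived from the input tensor rather than hard-coded.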
In practice I need this to undo the padding of a GRU output and get an embedding from the last meaningful timestep of every sequence in the batch. (I know I could use the PackedSequence format, but for now I’d like to try the padded format, as it’s more compatible with ConvNet models.)
For every batch element of the GRU output I’m taking the timestep slice that corresponds to the last timestep of that sequence. This discards the hidden states that correspond to padding timesteps; that’s what I mean by “undoing the padding”.
Thanks for your suggestion! In my original post notation it should be the following, right?
Thank you for that explanation, I understand your problem now.
Yes, that is my suggestion in the original post notation.