Indexing over a dimension

Is there a loopless native PyTorch equivalent of this op?

def index_dim(tensor, dim, index):
  # For each batch element k, take slice index[k] along `dim`
  # (dim - 1 because tensor[k] has already peeled off the batch dimension).
  return torch.stack([tensor[k].select(dim - 1, i) for k, i in enumerate(index.tolist())])

t = torch.rand(2, 4, 3)
i = torch.LongTensor([0, 1])
print(index_dim(t, 1, i).shape)
# torch.Size([2, 3])

In practice I need this to undo padding from a GRU output (I know I could use the PackedSequence format, but for now I'd like to try the padded format, since it's more compatible with ConvNet models) and get an embedding from the last meaningful timestep for every sequence in the batch.
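For reference, one loopless generalization of the op above is torch.gather: build an index tensor with size 1 along the selected dim and expand it over the remaining dims. A sketch (the function name index_dim_gather is mine):

```python
import torch

def index_dim_gather(tensor, dim, index):
    # index has one entry per batch element (dim 0).
    # Reshape it to (batch, 1, ..., 1) so it can be expanded.
    shape = [1] * tensor.dim()
    shape[0] = tensor.size(0)
    idx = index.view(shape)
    # Expand to the tensor's shape, but with size 1 along `dim`.
    expand_shape = list(tensor.shape)
    expand_shape[dim] = 1
    idx = idx.expand(expand_shape)
    # gather picks tensor[k, ..., idx[k, ...], ...] along `dim`;
    # squeeze removes the now-singleton selected dimension.
    return tensor.gather(dim, idx).squeeze(dim)

t = torch.rand(2, 4, 3)
i = torch.tensor([0, 1])
print(index_dim_gather(t, 1, i).shape)  # torch.Size([2, 3])
```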

For the case you have, you could do

t[torch.arange(t.size(0)), i]

but I think you’re trying to ask for something more general. How do you want to undo the padding from a GRU output?
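A quick check that this advanced-indexing form matches the loop version from the question:

```python
import torch

t = torch.rand(2, 4, 3)
i = torch.tensor([0, 1])

# Loop version from the question (dim=1, so select along dim 0 of t[k])
loop = torch.stack([t[k].select(0, ix) for k, ix in enumerate(i.tolist())])
# Advanced-indexing version: pairs batch index k with i[k]
vec = t[torch.arange(t.size(0)), i]

print(torch.equal(loop, vec))  # True
```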

For every batch element of the GRU output, I take the timestep slice that corresponds to the last timestep of that sequence. This discards the hidden states at padding timesteps; that's what I mean by "undoing the padding".
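That use case can be sketched as follows, assuming a batch-first padded input and a `lengths` tensor of true sequence lengths (the sizes here are made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical setup: 3 sequences, padded to length 5, 8 input features.
gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(3, 5, 8)            # (batch, time, features), zero-padded
lengths = torch.tensor([5, 3, 4])   # true length of each sequence

out, _ = gru(x)                     # out: (3, 5, 16)
# The last meaningful timestep for sequence k is at index lengths[k] - 1.
last = out[torch.arange(out.size(0)), lengths - 1]   # (3, 16)
print(last.shape)  # torch.Size([3, 16])
```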

Thanks for your suggestion! In my original post's notation it would be the following, right?

def index_dim(tensor, index):
  return tensor[torch.arange(len(tensor)), index]


Thank you for that explanation, I understand your problem now.

Yes, that is my suggestion in the original post notation.