Piggy-backing on @jnhwkim’s idea of using `gather`, here is a function that should mimic `index_select`
for inputs of arbitrary dimensionality. Note that it does not work with negative indexing.
def batched_index_select(input, dim, index):
    """Select entries of ``input`` along ``dim`` using ``torch.gather``.

    Mimics ``torch.index_select(input, dim, index)`` for inputs of any rank.

    Args:
        input: tensor of rank >= 1 (e.g. B x * x ... x *).
        dim: dimension to select along. Negative values count from the
            end, as in ``index_select``.
        index: tensor holding M indices. It is flattened to 1-D.

    Returns:
        A tensor shaped like ``input`` except that dimension ``dim`` has
        size M.

    Raises:
        IndexError: if ``dim`` is out of range for ``input``.
    """
    ndim = input.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"dim {dim} is out of range for a tensor with {ndim} dimensions"
        )
    # Normalize negative dims so the comparisons below match the real axis.
    dim %= ndim

    # Put the M indices on axis `dim` and size-1 axes everywhere else...
    views = [-1 if i == dim else 1 for i in range(ndim)]
    # ...then broadcast them across every other axis of `input`.
    expanse = list(input.shape)
    expanse[dim] = -1
    # reshape (not view) also accepts non-contiguous index tensors.
    index = index.reshape(views).expand(expanse)
    return torch.gather(input, dim, index)
And here’s a variation for when each batch element has its own indices (it reuses the same function name, so define only the one you need):
def batched_index_select(input, dim, index):
    """Per-batch ``index_select`` built on ``torch.gather``.

    For every batch element ``b``, selects ``input[b]`` entries along
    ``dim`` using that element's own row of indices ``index[b]``.

    Args:
        input: tensor of rank >= 2 whose dimension 0 is the batch
            dimension B (B x * x ... x *).
        dim: dimension to select along. It must not be the batch
            dimension. Negative values count from the end.
        index: tensor of shape B x M (one row of M indices per batch
            element).

    Returns:
        A tensor shaped like ``input`` except that dimension ``dim`` has
        size M.

    Raises:
        IndexError: if ``dim`` is out of range for ``input``.
        ValueError: if ``dim`` refers to the batch dimension 0.
    """
    ndim = input.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"dim {dim} is out of range for a tensor with {ndim} dimensions"
        )
    # Normalize negative dims so the comparisons below match the real axis.
    dim %= ndim
    if dim == 0:
        # The index rows are laid out along dim 0, so selecting along the
        # batch dimension itself is not meaningful here.
        raise ValueError("dim must not be the batch dimension (0)")

    # Keep the batch axis, put the M indices on axis `dim`, size 1 elsewhere...
    views = [input.shape[0]] + [
        -1 if i == dim else 1 for i in range(1, ndim)
    ]
    # ...then broadcast across all remaining axes of `input`.
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    # reshape (not view) also accepts non-contiguous index tensors.
    index = index.reshape(views).expand(expanse)
    return torch.gather(input, dim, index)