Batched index_select

Building on @jnhwkim’s idea of using `gather`, here is a function that mimics `index_select` for inputs of arbitrary dimensionality. It does not, however, support negative indexing.

# input: B x * x ... x *
# dim: scalar (negative values count from the last dimension)
# index: M
def batched_index_select(input, dim, index):
    """Select entries of ``input`` along ``dim`` with a single index vector.

    Mimics ``torch.index_select`` using ``torch.gather``. The M indices are
    reshaped to have size M along ``dim`` and size 1 everywhere else. They
    are then expanded (without copying) to match ``input`` in every other
    dimension, so the same indices are used for every position.

    Args:
        input: tensor of any rank.
        dim: dimension to select along. Negative values are interpreted
            relative to ``input.dim()``, as in most torch APIs.
        index: tensor holding M indices. Any shape is accepted; it is
            flattened to M entries. Presumably an int64 tensor, as
            ``torch.gather`` requires (confirm against callers).

    Returns:
        Tensor shaped like ``input`` except with size M along ``dim``.
    """
    ndim = input.dim()
    if dim < 0:
        # Normalize so the ``i == dim`` test below can match a real axis.
        dim += ndim
    # Put all M indices on axis ``dim``; every other axis gets size 1.
    views = [-1 if i == dim else 1 for i in range(ndim)]
    # Broadcast to the input's size on every axis except ``dim`` (-1 keeps M).
    expanse = list(input.shape)
    expanse[dim] = -1
    # ``reshape`` (unlike ``view``) also accepts non-viewable index layouts.
    index = index.reshape(views).expand(expanse)
    return torch.gather(input, dim, index)

And here is a variation for when each batch element has its own set of indices:

# input: B x * x ... x *
# dim: 0 < scalar (negative values count from the last dimension)
# index: B x M
def batched_index_select(input, dim, index):
    """Select entries along ``dim`` using a separate index row per batch element.

    For every batch position ``b`` (axis 0), the result equals
    ``input[b].index_select(dim - 1, index[b])``. It is computed in one
    ``torch.gather`` call by reshaping ``index`` to size B on axis 0, size M
    on ``dim`` and size 1 elsewhere. The reshaped index is then expanded
    (without copying) over the remaining axes.

    NOTE: this shares its name with the single-index variant above; rename
    one of them if both live in the same module.

    Args:
        input: tensor whose first axis is the batch axis (size B).
        dim: dimension to select along. It must refer to a non-batch axis.
            Negative values are interpreted relative to ``input.dim()``.
        index: tensor of B x M indices. Presumably int64, as
            ``torch.gather`` requires (confirm against callers).

    Returns:
        Tensor shaped like ``input`` except with size M along ``dim``.
    """
    ndim = input.dim()
    if dim < 0:
        # Normalize so the ``i == dim`` test below can match a real axis.
        dim += ndim
    # Keep the batch axis, put each row's M indices on ``dim``, size 1 elsewhere.
    views = [input.shape[0]] + [-1 if i == dim else 1 for i in range(1, ndim)]
    # Broadcast to the input's size on every non-batch axis except ``dim``.
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    # ``reshape`` (unlike ``view``) also accepts non-viewable index layouts.
    index = index.reshape(views).expand(expanse)
    return torch.gather(input, dim, index)
7 Likes