How can I use torch.index_select for 3D tensors?
I need a function that applies the usual index_select per batch element.
For example, I want to be able to do something like this:
Is this what you want? gather might be useful for this task:
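```python
import torch

# Batched index_select for 3-D tensors
# t: B x N x F, inds: B x E (int64), dim: 1 (the only dim this version is written for)
def batched_index_select(t, dim, inds):
    # B x E -> B x E x 1 -> B x E x F, so each index picks a whole feature row
    dummy = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2))
    out = t.gather(dim, dummy)  # b x e x f
    return out
```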
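A quick way to sanity-check the gather trick is to compare it against the per-batch loop the question describes. This is a sketch, assuming t is B x N x F, inds is a B x E int64 tensor, and dim = 1 (the only case the unsqueeze(2)/expand above handles); loop_index_select is a hypothetical helper written just for the comparison, and the tensor sizes are arbitrary.

```python
import torch

def batched_index_select(t, dim, inds):
    # t: B x N x F, inds: B x E (int64); assumes dim == 1
    dummy = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2))
    return t.gather(dim, dummy)  # B x E x F

def loop_index_select(t, inds):
    # Hypothetical reference: ordinary index_select applied to each batch element
    return torch.stack([t[b].index_select(0, inds[b]) for b in range(t.size(0))])

t = torch.randn(4, 10, 8)
inds = torch.randint(0, 10, (4, 3))
out = batched_index_select(t, 1, inds)
print(out.shape)                                     # torch.Size([4, 3, 8])
print(torch.equal(out, loop_index_select(t, inds)))  # True
```

The gather version avoids the Python loop over the batch, which matters once B gets large.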
Piggy-backing on @jnhwkim’s idea of using gather, here is a function that should mimic index_select for arbitrary-dimensional inputs. It just doesn’t work with a negative dim.
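```python
import torch

# input: B x * x ... x *
# dim: 0 <= scalar
# index: M
def batched_index_select(input, dim, index):
    # reshape index to 1 x ... x M x ... x 1 (M sits at position dim)
    views = [1 if i != dim else -1 for i in range(len(input.shape))]
    # then broadcast it to input's shape everywhere except along dim
    expanse = list(input.shape)
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)
```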
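With a single index shared by every batch element, this should give the same result as calling index_select directly, so in that case input.index_select(dim, index) is the simpler option. A minimal check over every non-negative dim of a 4-D tensor (the sizes and index values are arbitrary choices for illustration):

```python
import torch

def batched_index_select(input, dim, index):
    # Shared 1-D index of length M, applied along dim (dim >= 0)
    views = [1 if i != dim else -1 for i in range(len(input.shape))]
    expanse = list(input.shape)
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)

x = torch.randn(2, 5, 6, 7)
base = torch.tensor([4, 0, 0, 2])
results = []
for dim in range(x.dim()):
    idx = base % x.size(dim)  # keep indices in range for every dim
    results.append(torch.equal(batched_index_select(x, dim, idx),
                               x.index_select(dim, idx)))
print(results)  # [True, True, True, True]
```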
And here’s a variation if you have different indexes for each batch:
```python
import torch

# input: B x * x ... x *
# dim: 0 < scalar
# index: B x M
def batched_index_select(input, dim, index):
    # reshape index to B x 1 x ... x M x ... x 1 (M sits at position dim)
    views = [input.shape[0]] + \
        [1 if i != dim else -1 for i in range(1, len(input.shape))]
    # broadcast to input's shape, keeping the batch size and M as they are
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)
```
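To make the per-batch semantics concrete, here is a small worked case, assuming a 2 x 3 x 4 input whose entries are their own flat positions, dim = 2, and one row of indices per batch element. The expectation is out[b, n, m] == input[b, n, index[b, m]].

```python
import torch

def batched_index_select(input, dim, index):
    # index: B x M, one row of indices per batch element; requires dim > 0
    views = [input.shape[0]] + \
        [1 if i != dim else -1 for i in range(1, len(input.shape))]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)

x = torch.arange(24).view(2, 3, 4)   # x[b, n, f] == 12*b + 4*n + f
index = torch.tensor([[0, 3],        # columns picked for batch element 0
                      [2, 1]])       # columns picked for batch element 1
out = batched_index_select(x, 2, index)
print(out.tolist())
# [[[0, 3], [4, 7], [8, 11]], [[14, 13], [18, 17], [22, 21]]]
```

Batch element 0 keeps columns 0 and 3 of every row, while batch element 1 keeps columns 2 and 1, in that order.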
```python
import torch

# input: B x * x ... x *
# dim: 0 <= scalar
# index: M
def batched_index_select(input, dim, index):
    views = [1 if i != dim else -1 for i in range(len(input.shape))]
    expanse = list(input.shape)
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    gathered = torch.gather(input, dim, index)
    # making the first dim of output be B
    return torch.cat(torch.chunk(gathered, chunks=index.shape[0], dim=dim), dim=0)
```
That is exactly what I wanted, thanks. Since I did not want to rely on the exact batch size, I used this variant:
```python
import torch

def batched_index_select(input, dim, index):
    # index: B x M; add singleton dims everywhere except batch (0) and dim
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input, dim, index)
```
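Because this variant only inserts singleton dimensions with unsqueeze, it never reads the batch size explicitly. A sketch that checks it against a plain per-batch index_select loop for every dim >= 1 of a 4-D input; loop_reference is a hypothetical helper and the shapes are arbitrary.

```python
import torch

def batched_index_select(input, dim, index):
    # index: B x M; add singleton dims everywhere except batch (0) and dim
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input, dim, index)

def loop_reference(input, dim, index):
    # Hypothetical helper: plain index_select on each batch element
    # (dim - 1 because input[b] drops the batch dimension)
    return torch.stack([input[b].index_select(dim - 1, index[b])
                        for b in range(input.size(0))])

x = torch.randn(3, 4, 5, 6)
results = []
for dim in range(1, x.dim()):
    index = torch.randint(0, x.size(dim), (x.size(0), 2))
    results.append(torch.equal(batched_index_select(x, dim, index),
                               loop_reference(x, dim, index)))
print(results)  # [True, True, True]
```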
I have a similar task, so maybe someone can help me!
So what I have are two tensors:
- an indices tensor indices with shape (2, 5, 2), where the last dimension holds the indices in the x and y dimensions
- a “value tensor” value with shape (2, 5, 2, 16, 16), where I want the last two dimensions to be selected with the x and y indices
To be more concrete, the indices are between 0 and 15, and I want to get an output like:
```python
out = value[:, :, :, x_indices, y_indices]
```
For example, the indices for the first item are i_y = indices[0, 0, 0] and i_x = indices[0, 0, 1]. The corresponding output should then be out[0, 0, :, i_y, i_x]. Can anybody help me here? Thanks a lot!
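One possible approach, sketched under the assumption (taken from the example) that indices[..., 0] is the row (y) index and indices[..., 1] the column (x) index, so that out[b, n, :] = value[b, n, :, i_y, i_x]: flatten the two spatial dimensions, turn each (y, x) pair into a row-major linear index y * W + x, and apply the same gather trick along the flattened dimension. select_yx is a hypothetical name.

```python
import torch

def select_yx(value, indices):
    # value:   B x N x C x H x W
    # indices: B x N x 2 (int64), assumed to hold (y, x) pairs
    B, N, C, H, W = value.shape
    iy, ix = indices[..., 0], indices[..., 1]          # each B x N
    flat = value.flatten(start_dim=3)                  # B x N x C x (H*W), row-major
    lin = iy * W + ix                                  # linear position of (y, x), B x N
    lin = lin.reshape(B, N, 1, 1).expand(B, N, C, 1)   # same position for every channel
    return flat.gather(3, lin).squeeze(3)              # B x N x C

value = torch.randn(2, 5, 2, 16, 16)
indices = torch.randint(0, 16, (2, 5, 2))
out = select_yx(value, indices)

# Reference: explicit loop over the first two dimensions
ref = torch.empty_like(out)
for b in range(2):
    for n in range(5):
        y, x = indices[b, n].tolist()
        ref[b, n] = value[b, n, :, y, x]

print(out.shape)              # torch.Size([2, 5, 2])
print(torch.equal(out, ref))  # True
```

If the pairs are actually stored as (x, y), swapping the two lines that read iy and ix should be the only change needed.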
Not sure if this is comparable to dashesy’s in terms of generality, but here’s a simplified case + tests for selecting vectors, e.g. via argmin or multinomial: