Batched index_select

How can I use torch.index_select for 3D tensors?
I need a function that applies the usual index_select per batch element.
For example, I want to be able to do something like this:

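import torch

A = torch.FloatTensor([[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6]], [[7, 8, 9], [7, 8, 9]]])
ind = torch.LongTensor([[0, 1], [1, 2], [0, 2]])
# Desired: select along dim 2 of A[b] using ind[b], for every batch element b.
# This call raises an error as written, because index_select only accepts a 1-D index.
torch.index_select(A, 2, ind)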

What is the most efficient and simple way to do this?
Thanks.


I’m not sure if this is the best way to do it, but this works:

torch.cat([ torch.index_select(a, 1, i).unsqueeze(0) for a, i in zip(A, ind) ])

This applies the regular index_select to each batch element of A and ind, then concatenates the results.

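A self-contained version of the loop approach, using the example tensors from the question. As a small variation (not part of the original reply), `torch.stack` replaces `unsqueeze(0)` + `torch.cat`; the two give the same result.

```python
import torch

# Example data from the question: B x N x F = 3 x 2 x 3
A = torch.tensor([[[1., 2., 3.], [1., 2., 3.]],
                  [[4., 5., 6.], [4., 5., 6.]],
                  [[7., 8., 9.], [7., 8., 9.]]])
# Per-batch indices into the last dimension: B x M = 3 x 2
ind = torch.tensor([[0, 1], [1, 2], [0, 2]])

# Inside the loop each a is N x F, so "dim 2 of A" becomes dim 1 of a.
# Stacking the per-batch results recreates the leading batch dimension.
out = torch.stack([torch.index_select(a, 1, i) for a, i in zip(A, ind)])
# out has shape 3 x 2 x 2; e.g. out[1] is [[5., 6.], [5., 6.]]
```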

Yeah, it’s pretty straightforward. I somehow thought we could do it without for loops… I’m probably overreaching. Thank you.


Is there any way to do a batched index_select without using a loop?

This is a bit out of context, but I hope you can figure out what is going on here:

import torch


def get_selected_matrices(probs, references, dim=1, sanity_check=False):
    """
    batched index select
    probs - tensor of shape batch_size x seq_len x vocab_size
    references - indices into the vocab dim, batch_size x seq_len (nested list, array or tensor)
    dim - dimension of each batch element to select along; the flattening below assumes dim=1
    """
    # Loop-based equivalent:
    # return torch.cat([torch.index_select(a, dim, i).unsqueeze(0)
    #                   for a, i in zip(probs, torch.as_tensor(references))])

    batch_size, seq_len, vocab_size = probs.size()
    references = torch.as_tensor(references, dtype=torch.long, device=probs.device)  # batch_size x seq_len
    # Offset of each batch element once the batch and vocab dims are merged
    vocab_extension = torch.arange(batch_size, device=probs.device) * vocab_size  # batch_size
    references_extended_vocab = (references + vocab_extension.unsqueeze(-1)).view(-1)  # batch_size * seq_len
    probs_extended_vocab = torch.transpose(probs, 0, 1).contiguous().view(seq_len, -1)  # seq_len x batch_size * vocab_size
    probs_reduced_extended_vocab = torch.index_select(
        probs_extended_vocab, dim, references_extended_vocab
    )  # seq_len x batch_size * seq_len
    probs_reduced_vocab = torch.transpose(
        probs_reduced_extended_vocab.view(seq_len, batch_size, seq_len), 0, 1
    )  # batch_size x seq_len x seq_len
    if sanity_check:
        probs_reduced_vocab_loop = torch.cat(
            [torch.index_select(a, dim, i).unsqueeze(0) for a, i in zip(probs, references)]
        )
        if not torch.equal(probs_reduced_vocab, probs_reduced_vocab_loop):
            # Print before raising, otherwise these lines are never reached
            print(probs_reduced_vocab)
            print(probs_reduced_vocab_loop)
            raise AssertionError('got wrong probs with reduced vocab')
    return probs_reduced_vocab

This just converts the two-dimensional (batch, vocab) index into a one-dimensional index over the flattened tensor.

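The same flattening trick, stripped down to the 3-D case from the question (select along dim 2 with a per-batch B x M index). `batched_index_select_flat` is a hypothetical name used only for this sketch: it merges the batch dim with the selected dim, shifts each batch element's indices by `b * F`, and then does a single ordinary `index_select`.

```python
import torch


def batched_index_select_flat(A, ind):
    """Hypothetical helper: out[b, i, j] = A[b, i, ind[b, j]] (select along dim 2).

    A:   B x N x F
    ind: B x M (long), per-batch indices into the last dimension
    """
    B, N, F = A.shape
    M = ind.shape[1]
    # N x (B * F): column b * F + f holds A[b, :, f]
    A_flat = A.permute(1, 0, 2).reshape(N, B * F)
    # Shift batch b's indices by b * F so they address that batch's block of columns
    offsets = torch.arange(B, device=ind.device).unsqueeze(1) * F  # B x 1
    flat_ind = (ind + offsets).reshape(-1)                          # B * M
    # One ordinary index_select on the flattened tensor: N x (B * M)
    selected = torch.index_select(A_flat, 1, flat_ind)
    # Split the merged dim back into B x M and restore batch-first order
    return selected.reshape(N, B, M).permute(1, 0, 2)               # B x N x M


A = torch.tensor([[[1., 2., 3.], [1., 2., 3.]],
                  [[4., 5., 6.], [4., 5., 6.]],
                  [[7., 8., 9.], [7., 8., 9.]]])
ind = torch.tensor([[0, 1], [1, 2], [0, 2]])
out = batched_index_select_flat(A, ind)  # 3 x 2 x 2
```

The cost is one `permute`/`reshape` copy of the input, in exchange for a single kernel call instead of a Python loop over the batch.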

Is this what you want? gather might be useful for this task:

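import torch

# Batched index_select; works as intended for dim=1 (selecting rows):
# t: B x N x F, inds: B x M  ->  out[b, i, :] = t[b, inds[b, i], :]
def batched_index_select(t, dim, inds):
    # Repeat each index across the last dimension so its shape matches the output (B x M x F)
    dummy = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2))
    out = t.gather(dim, dummy) # b x e x f
    return out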
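With `dim=1`, the gather-based helper above picks whole rows per batch element: `out[b, i, :] = t[b, inds[b, i], :]`. A quick check on made-up random data, comparing it against plain advanced indexing with a broadcast batch index (an alternative to `gather`, not part of the original reply):

```python
import torch


def batched_index_select(t, dim, inds):
    # Copy of the helper above: expand inds (B x M) to B x M x F, then gather
    dummy = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2))
    return t.gather(dim, dummy)


torch.manual_seed(0)
t = torch.randn(4, 5, 3)              # B x N x F (arbitrary test data)
inds = torch.randint(0, 5, (4, 2))    # B x M row indices into dim 1

out_gather = batched_index_select(t, 1, inds)   # B x M x F

# Advanced indexing: the B x 1 batch index broadcasts against the B x M inds
batch = torch.arange(t.size(0)).unsqueeze(1)    # B x 1
out_indexing = t[batch, inds]                   # B x M x F

print(torch.equal(out_gather, out_indexing))    # True
```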

Piggy-backing on @jnhwkim’s idea of using gather, here is a function that should mimic index_select for inputs of arbitrary dimension. It doesn’t work with a negative dim, though.

import torch

# input: B x * x ... x *
# dim: 0 <= scalar
# index: M
def batched_index_select(input, dim, index):
    # Singleton everywhere except at `dim`, where the M indices go
    views = [1 if i != dim else -1 for i in range(len(input.shape))]
    # Expand to the input's shape, keeping length M at `dim`
    expanse = list(input.shape)
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)
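Because the 1-D index is shared by every batch element, this version should agree with plain `torch.index_select` along any non-negative dim. A quick check (Python 3, so `range` instead of `xrange`); the test tensor is arbitrary:

```python
import torch


def batched_index_select(input, dim, index):
    # index: 1-D (M), shared by every batch element
    views = [1 if i != dim else -1 for i in range(input.dim())]
    expanse = list(input.shape)
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)


x = torch.arange(48.).reshape(2, 4, 6)
idx = torch.tensor([1, 0, 1])  # valid for every dim, since the smallest dim has size 2

for d in range(x.dim()):
    print(d, torch.equal(batched_index_select(x, d, idx), torch.index_select(x, d, idx)))
```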

And here’s a variation if you have different indices for each batch element:

import torch

# input: B x * x ... x *
# dim: 0 < scalar
# index: B x M
def batched_index_select(input, dim, index):
    # Keep the batch dim, put the M indices at `dim`, singleton elsewhere
    views = [input.shape[0]] + \
        [1 if i != dim else -1 for i in range(1, len(input.shape))]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)
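Running the per-batch variant on the example from the question, selecting along dim 2 with one row of `ind` per batch element:

```python
import torch


def batched_index_select(input, dim, index):
    # index: B x M, one row of indices per batch element (dim > 0)
    views = [input.shape[0]] + [1 if i != dim else -1 for i in range(1, input.dim())]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)


A = torch.tensor([[[1., 2., 3.], [1., 2., 3.]],
                  [[4., 5., 6.], [4., 5., 6.]],
                  [[7., 8., 9.], [7., 8., 9.]]])
ind = torch.tensor([[0, 1], [1, 2], [0, 2]])

out = batched_index_select(A, 2, ind)
print(out.tolist())
# [[[1.0, 2.0], [1.0, 2.0]], [[5.0, 6.0], [5.0, 6.0]], [[7.0, 9.0], [7.0, 9.0]]]
```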

Hi @jnhwkim,
The expected output for the example in the question is:
(0 ,.,.) =
1 2
1 2

(1 ,.,.) =
5 6
5 6

(2 ,.,.) =
7 9
7 9
[torch.FloatTensor of size 3x2x2]

Running your version of batched_index_select gives a different output:
(0 ,.,.) =
1 1 1
2 2 2

(1 ,.,.) =
5 5 5
6 6 6

(2 ,.,.) =
7 7 7
9 9 9
[torch.FloatTensor of size 3x2x3]
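The difference comes from how the index is expanded: `inds.unsqueeze(2).expand(...)` repeats each index along the last dimension, so `gather(2, ...)` returns `A[b, i, ind[b, i]]` repeated three times. To select along dim 2, the B x M index has to be broadcast along dim 1 instead. A minimal sketch for the 3-D case only:

```python
import torch

A = torch.tensor([[[1., 2., 3.], [1., 2., 3.]],
                  [[4., 5., 6.], [4., 5., 6.]],
                  [[7., 8., 9.], [7., 8., 9.]]])
ind = torch.tensor([[0, 1], [1, 2], [0, 2]])  # B x M

# Goal: out[b, i, j] = A[b, i, ind[b, j]], i.e. the same column indices for every row i
B, N, F = A.shape
index = ind.unsqueeze(1).expand(B, N, ind.size(1))  # B x N x M
out = A.gather(2, index)                            # B x N x M
print(out.tolist())
# [[[1.0, 2.0], [1.0, 2.0]], [[5.0, 6.0], [5.0, 6.0]], [[7.0, 9.0], [7.0, 9.0]]]
```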

Maybe this is a slightly better version:

import torch
# input: B x * x ... x *
# dim: 0 <= scalar
# index: M
def batched_index_select(input, dim, index):
    views = [1 if i != dim else -1 for i in range(len(input.shape))]
    expanse = list(input.shape)
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    # making the first dim of output be B
    return torch.cat(torch.chunk(torch.gather(input, dim, index), chunks=index.shape[0], dim=dim), dim=0)

That is exactly what I wanted, thanks. Since I did not want to rely on the exact batch size, I used this variant:

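import torch

def batched_index_select(input, dim, index):
    # index: B x M. Insert singleton dims everywhere except dim 0 (batch) and `dim`,
    # so the batch size never has to be read from input.shape.
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input, dim, index)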
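A quick check of the unsqueeze-based variant on a 4-D input, comparing each non-batch dim against the per-batch `index_select` loop from earlier in the thread. The tensor sizes are arbitrary test values:

```python
import torch


def batched_index_select(input, dim, index):
    # index: B x M; add singleton dims everywhere except dim 0 and `dim`
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input, dim, index)


torch.manual_seed(0)
x = torch.randn(3, 4, 5, 6)
for dim in (1, 2, 3):
    index = torch.randint(0, x.size(dim), (3, 2))
    # Reference: loop over the batch; dim shifts by one inside each element
    expected = torch.stack([torch.index_select(a, dim - 1, i) for a, i in zip(x, index)])
    assert torch.equal(batched_index_select(x, dim, index), expected)
```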

I have a similar task, so maybe someone can help me!

I have two tensors:

  • an index tensor indices with shape (2, 5, 2), where the last dimension holds the indices for the x and y dimensions
  • a “value tensor” value with shape (2, 5, 2, 16, 16), whose last two dimensions I want to select with the x and y indices

To be more concrete, the indices are between 0 and 15, and I want an output like:

out = value[:, :, :, x_indices, y_indices]

For example, the indices for the first item are i_y = indices[0, 0, 0] and i_x = indices[0, 0, 1]. The corresponding output out[0, 0, :] should then be value[0, 0, :, i_y, i_x]. Can anybody help me here? Thanks a lot!

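One way to do this, assuming the intended result is `out[b, n, c] = value[b, n, c, indices[b, n, 0], indices[b, n, 1]]` (so `out` has shape 2 x 5 x 2): flatten the two spatial dims so each (y, x) pair becomes a single linear index `y * W + x`, then use `gather` as in the answers above. The second half does the same selection with advanced indexing for comparison. The random data is only for illustration.

```python
import torch

torch.manual_seed(0)
value = torch.randn(2, 5, 2, 16, 16)        # B x N x C x H x W
indices = torch.randint(0, 16, (2, 5, 2))   # B x N x 2, last dim = (y, x)

i_y = indices[..., 0]                       # B x N
i_x = indices[..., 1]                       # B x N
W = value.size(-1)

# Merge H and W so a (y, x) pair becomes one index into the last dim
flat_value = value.flatten(start_dim=-2)    # B x N x C x (H * W)
flat_index = i_y * W + i_x                  # B x N
# Broadcast the index over the channel dim and gather along the last dim
gather_index = flat_index[:, :, None, None].expand(-1, -1, value.size(2), 1)  # B x N x C x 1
out = flat_value.gather(-1, gather_index).squeeze(-1)                         # B x N x C

# Same result via advanced indexing: move C last, then index B, N, H, W together
b = torch.arange(value.size(0))[:, None]    # B x 1
n = torch.arange(value.size(1))[None, :]    # 1 x N
out_adv = value.permute(0, 1, 3, 4, 2)[b, n, i_y, i_x]  # B x N x C
```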

Not sure if this is comparable to dashesy’s in terms of generality, but here’s a simplified case + tests for selecting vectors, e.g. via argmin or multinomial:


Thanks, @eacousineau. I was looking for this.