PyTorch equivalent of tf.gather

So I have an input tensor with shape [16, 1, 125, 256] and a selector tensor with shape [124, 2]. Is there a PyTorch equivalent of tf.gather(input, selector, axis=2)? To begin with, how does TensorFlow handle the two tensors even though they don't have the same number of dimensions (unlike torch.gather, where they must be equal)? Furthermore, what must dim in torch.gather be to behave like axis=2 in tf.gather?

For more context, I’m trying to make a PyTorch version of this, which defines a frame function that expands the signal's axis dimension into frames of length frame_length.
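Assuming the function in question behaves like tf.signal.frame with pad_end=False, one alternative (not suggested in this thread) skips gather entirely and uses Tensor.unfold, which produces overlapping windows as a view. The helper name frame and its arguments are hypothetical; the movedim call places the frame_length dimension right after the frames dimension, which is the layout I believe tf.signal.frame produces.

```python
import torch

def frame(signal: torch.Tensor, frame_length: int, frame_step: int, axis: int = -1) -> torch.Tensor:
    """Hypothetical helper: split `axis` into overlapping frames, assumed to match
    tf.signal.frame with pad_end=False (incomplete trailing frames are dropped)."""
    axis = axis % signal.dim()  # normalise negative axes
    # unfold replaces `axis` by the number of frames and appends a new last dim of size frame_length
    frames = signal.unfold(axis, frame_length, frame_step)
    # move that new last dim so it sits right after the frames dim
    return frames.movedim(-1, axis + 1)

x = torch.randn(16, 1, 125, 256)
out = frame(x, frame_length=2, frame_step=1, axis=2)
print(out.shape)  # torch.Size([16, 1, 124, 2, 256])
```

With frame_length=2 and frame_step=1 this yields the same tensor as gathering with a [124, 2] selector whose rows are [i, i + 1].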

Could you explain what tf.gather(input, selector, axis=2) is doing when there is no matching dimension between input and selector? Maybe with a small example of what you are doing?

It gathers slices from input at the indices given by selector along the specified axis. A workaround I found is input[:, :, selector, :], which returns the same output as tf.gather(input, selector, axis=2). However, how do I do the same thing if the shape of input is not known? Here are some more examples (for a 4-D input):

tf.gather(input, selector, axis=0) is the same as input[selector, ...]
tf.gather(input, selector, axis=1) is the same as input[:, selector, ...]
tf.gather(input, selector, axis=2) is the same as input[..., selector, :]
tf.gather(input, selector, axis=3) is the same as input[..., selector]
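For an input whose rank is only known at runtime, one sketch is to build the index tuple programmatically: axis full slices followed by the selector tensor, with trailing dimensions kept implicitly. The helper name gather_axis is hypothetical, and negative axes are normalised with a modulo.

```python
import torch

def gather_axis(inp: torch.Tensor, selector: torch.Tensor, axis: int) -> torch.Tensor:
    """Hypothetical helper mimicking tf.gather(inp, selector, axis=axis)
    for a tensor of any rank, using basic slices plus one advanced index."""
    axis = axis % inp.dim()  # allow negative axes such as -1
    # One full slice (":") for every dim before `axis`, then the index tensor.
    index = (slice(None),) * axis + (selector,)
    return inp[index]

inp = torch.randn(16, 1, 125, 256)
selector = torch.randint(0, 125, (124, 2))
print(gather_axis(inp, selector, axis=2).shape)  # torch.Size([16, 1, 124, 2, 256])
```

Because only one dimension receives a tensor index, PyTorch inserts the selector's shape in place of that dimension, which is the same shape rule tf.gather follows.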

It’s still not 100% clear to me what this does 😃, but I guess this would work:

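import torch
dim = 2
# index_select only accepts a 1-D index, so flatten the selector, then restore its shape
new_size = inp.size()[:dim] + selector.size() + inp.size()[dim+1:]
out = inp.index_select(dim, selector.reshape(-1)).view(new_size)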
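On the dim question: torch.gather with dim=2 can reproduce the same result, but because its index must have the same number of dimensions as the input, the selector has to be flattened and expanded across the other dimensions first. A sketch, assuming the shapes from the original question and an integer (int64) selector:

```python
import torch

inp = torch.randn(16, 1, 125, 256)
selector = torch.randint(0, 125, (124, 2))
dim = 2

# torch.gather needs an index with the same number of dims as inp, so flatten
# the selector and broadcast it across every other dimension with expand.
flat = selector.reshape(-1)                       # shape [248]
view_shape = [1] * inp.dim()
view_shape[dim] = flat.numel()                    # [1, 1, 248, 1]
index = flat.view(view_shape).expand(*inp.shape[:dim], flat.numel(), *inp.shape[dim + 1:])
out = torch.gather(inp, dim, index)               # shape [16, 1, 248, 256]

# Restore the selector's shape in place of `dim`, as tf.gather would.
out = out.view(*inp.shape[:dim], *selector.shape, *inp.shape[dim + 1:])
print(out.shape)  # torch.Size([16, 1, 124, 2, 256])
```

For this case index_select is simpler, since the same indices are used for every position in the other dimensions; gather only pays off when those indices differ.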

It works. Thanks for this! One last question: is there a way to slice along all dimensions of a tensor at once? Say, for example, I have a tensor inp with shape [16, 1, 125, 256]. I then have two lists (this is a simple example)

begin = [0, 0, 0, 0]
end = [16, 1, 125, 256]

that specify the beginning and end indices to slice, respectively. From this, I perform slicing as

out = inp[begin[0]:end[0], begin[1]:end[1], begin[2]:end[2], begin[3]:end[3]]

How do I do this for a tensor with an arbitrary number of dimensions?

The way I would do this is with a series of narrow calls (keep in mind that narrow returns a view and never copies memory, so it is very fast):

out = inp
# narrow(dim, start, length) keeps `length` entries starting at `start` along `dim`
for dim, (b, e) in enumerate(zip(begin, end)):
    out = out.narrow(dim, b, e - b)
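A variant without the loop, offered as a sketch: build one slice object per dimension and index with the resulting tuple in a single step. Like narrow, basic slicing returns a view rather than a copy. The helper name slice_all is hypothetical.

```python
import torch

def slice_all(inp: torch.Tensor, begin, end) -> torch.Tensor:
    """Hypothetical helper: inp[begin[0]:end[0], begin[1]:end[1], ...] for any rank."""
    # One slice object per dimension; indexing with the tuple slices them all at once.
    return inp[tuple(slice(b, e) for b, e in zip(begin, end))]

inp = torch.randn(16, 1, 125, 256)
out = slice_all(inp, [2, 0, 10, 0], [8, 1, 20, 128])
print(out.shape)  # torch.Size([6, 1, 10, 128])
```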

These work as expected. Thanks again for your help!

Hey @albanD, I just came across index_select. It looks like a simpler version of torch.gather?
For a 2-D matrix, index_select can only select the same columns for every row, whereas gather can select different columns for different rows.
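A small 2-D example (values chosen arbitrarily) illustrating that difference:

```python
import torch

m = torch.tensor([[10, 11, 12],
                  [20, 21, 22],
                  [30, 31, 32]])

# index_select: one 1-D index, applied identically to every row
print(torch.index_select(m, 1, torch.tensor([2, 0])))
# values: [[12, 10], [22, 20], [32, 30]]

# gather: the index has m's rank, so each row picks its own column
print(torch.gather(m, 1, torch.tensor([[0], [2], [1]])))
# values: [[10], [22], [31]]
```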

I’m actually trying to use gather, but it failed with some errors (see "Torch.gather() doesn't work during backward()").