So I have an input tensor with shape [16, 1, 125, 256] and a selector tensor with shape [124, 2]. Is there a PyTorch equivalent of tf.gather(input, selector, axis=2)? To begin with, how does TensorFlow handle the two tensors when they do not have the same number of dimensions (unlike torch.gather, where they must be equal)? Furthermore, what must dim in torch.gather be to match axis=2 in tf.gather?
For more context, I’m trying to make a PyTorch version of this, which defines a frame function that expands signal's axis dimension into frames of frame_length.
Could you explain what tf.gather(input, selector, axis=2) is doing when there is no matching dimension between input and selector? Maybe with a small example of what you are doing?
It gathers the slices of input at the indices given by selector, along the specified axis. A workaround I found is input[:, :, selector, :], which returns the same output as tf.gather(input, selector, axis=2). However, how do I do the same thing if the shape of input is not known? Here are some more examples:
For the 4-D input above:
tf.gather(input, selector, axis=0) is the same as input[selector, ...]
tf.gather(input, selector, axis=1) is the same as input[:, selector, ...]
tf.gather(input, selector, axis=2) is the same as input[..., selector, :]
tf.gather(input, selector, axis=3) is the same as input[..., selector]
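For an input whose number of dimensions is not known in advance, one option is to build the index tuple programmatically. The sketch below is only an illustration: gather_along_axis is a hypothetical helper name, not a PyTorch function, and it assumes selector is an integer (long) tensor. It relies on the fact that the result of tf.gather has shape input.shape[:axis] + selector.shape + input.shape[axis+1:], which is also what PyTorch's advanced indexing produces when a single index tensor is placed at position axis.

```python
import torch

def gather_along_axis(inp: torch.Tensor, selector: torch.Tensor, axis: int) -> torch.Tensor:
    # Hypothetical helper (not part of PyTorch): mimics
    # tf.gather(inp, selector, axis=axis) for an input of any rank.
    axis = axis % inp.dim()  # allow negative axes such as -1
    # Full slices for every dimension before `axis`, then the index tensor.
    # Trailing dimensions are kept whole implicitly.
    index = (slice(None),) * axis + (selector,)
    return inp[index]

torch.manual_seed(0)
inp = torch.arange(16 * 1 * 125 * 256).reshape(16, 1, 125, 256)
selector = torch.randint(0, 125, (124, 2))  # int64 indices into dim 2

out = gather_along_axis(inp, selector, axis=2)
print(out.shape)  # torch.Size([16, 1, 124, 2, 256])
```

The selector's shape [124, 2] replaces the single dimension of size 125, so the two tensors do not need the same number of dimensions. This differs from torch.gather, which requires index and input to have equal rank.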
It works. Thanks for this! One last question: is there a way to slice along all dimensions of a tensor simultaneously? Say, for example, I have a tensor inp with shape [16, 1, 125, 256]. I then have two lists (this is a simple example)
begin = [0, 0, 0, 0]
end = [16, 1, 125, 256]
that specify the beginning and end indices to slice, respectively. From this, I perform slicing as
out = inp[begin[0]:end[0], begin[1]:end[1], begin[2]:end[2], begin[3]:end[3]]
How do I do this for a tensor with an arbitrary number of dimensions?
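One way to sketch this is to build a tuple of Python slice objects from the two lists and index with it. This works because inp[a:b, c:d] is equivalent to inp[(slice(a, b), slice(c, d))]. The helper name slice_all_dims is made up for illustration, and the begin and end values below are arbitrary.

```python
import torch

def slice_all_dims(inp: torch.Tensor, begin, end) -> torch.Tensor:
    # Hypothetical helper: one slice(b, e) per dimension, packed in a tuple.
    # If begin/end are shorter than inp.dim(), the remaining dims are kept whole.
    index = tuple(slice(b, e) for b, e in zip(begin, end))
    return inp[index]

inp = torch.arange(16 * 1 * 125 * 256).reshape(16, 1, 125, 256)
begin = [0, 0, 10, 0]
end = [16, 1, 20, 128]

out = slice_all_dims(inp, begin, end)
print(out.shape)  # torch.Size([16, 1, 10, 128])
```

The same function works unchanged for a 2-D or 5-D tensor, as long as begin and end have one entry per dimension to be sliced.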
Hey @albanD, I just came across index_select. It looks like a simpler version of torch.gather?
For a 2D matrix, index_select can only select the same columns for every row, whereas gather can select different columns for different rows.
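A small example of that difference, using a made-up 2 x 3 matrix:

```python
import torch

m = torch.tensor([[10, 11, 12],
                  [20, 21, 22]])

# index_select: a 1-D index, applied identically to every row.
cols = torch.tensor([0, 2])
same_cols = torch.index_select(m, dim=1, index=cols)
print(same_cols)  # tensor([[10, 12], [20, 22]])

# gather: the index has the same number of dims as m,
# so each row can pick its own columns.
idx = torch.tensor([[0, 2],
                    [1, 1]])
per_row = torch.gather(m, dim=1, index=idx)
print(per_row)  # tensor([[10, 12], [21, 21]])
```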