Batch indexing for RNN output

So the output of my network looks like this:

output = tensor([[[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.0315, -0.1837],
     [ 0.0318, -0.1828],
     [ 0.0322, -0.1822],
     [ 0.0324, -0.1819],
     [ 0.0327, -0.1817],
     [ 0.0328, -0.1815],
     [ 0.0330, -0.1815],
     [ 0.0331, -0.1814],
     [ 0.0332, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]]])

It has a shape of [8, 25, 2] (batch size 8, 25 time steps, 2 features).
Here 8 is my batch size, and I would like to get one data point from every sequence in the batch, at the following locations:

index = tensor([24, 10,  3,  3,  1,  1,  1,  0])

So the value at index 24 from the first sequence in the batch, the value at index 10 from the second, and so on.

Now I am having trouble figuring out the syntax.
I’ve tried

torch.gather(output, 0, index)

But it keeps telling me that my dimensions don’t match (gather expects an index tensor with the same number of dimensions as the input).
And trying

output[ : ,index]

just gets me the values at all of the indices for every sequence in the batch.
What would be the correct syntax to get these values?

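output[torch.arange(batch_size), index] should work; it pairs each batch position with its own index and returns a tensor of shape [batch_size, 2]. Equivalently, with a 2D tensor of (batch, position) pairs, index_2d = [[0, 24], [1, 10], …], you can use output[index_2d[:, 0], index_2d[:, 1]].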

1 Like