So the output of my network looks like this:
output = tensor([[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.0315, -0.1837],
[ 0.0318, -0.1828],
[ 0.0322, -0.1822],
[ 0.0324, -0.1819],
[ 0.0327, -0.1817],
[ 0.0328, -0.1815],
[ 0.0330, -0.1815],
[ 0.0331, -0.1814],
[ 0.0332, -0.1814],
[ 0.0333, -0.1814],
[ 0.0333, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]]])
This has shape [8, 25, 2] (each batch element has 25 rows).
Here 8 is my batch size, and I would like to get one data point from every element of the batch, at the following positions along the second dimension:
index = tensor([24, 10, 3, 3, 1, 1, 1, 0])
That is, the value at index 24 from the first batch element, the value at index 10 from the second, and so on.
I'm having trouble figuring out the syntax.
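To pin down the intended selection before worrying about vectorized syntax, a plain Python loop can serve as a reference. This is a minimal sketch that assumes a random stand-in tensor of shape [8, 25, 2] (not the real network output; dim 1 needs at least 25 positions for index 24 to be valid):

```python
import torch

# Hypothetical stand-in for the network output: 8 batch elements,
# 25 positions each, 2 values per position (random, not the real data).
torch.manual_seed(0)
output = torch.randn(8, 25, 2)
index = torch.tensor([24, 10, 3, 3, 1, 1, 1, 0])

# Reference semantics: for batch element b, take the row at position index[b].
picked = torch.stack([output[b, index[b].item()] for b in range(output.size(0))])

print(picked.shape)  # torch.Size([8, 2])
```

The loop is slow for large batches, but it gives something to compare any vectorized version against.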
I’ve tried
torch.gather(output, 0, index)
But it keeps telling me that my dimensions don’t match.
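The mismatch is consistent with how `torch.gather` works: the index tensor must have the same number of dimensions as the input, and it selects along the dimension you pass. Here the positions live along dim 1, not dim 0, and `index` is 1-D while `output` is 3-D. A sketch of a `gather` call with matching dimensions, again using a random stand-in of shape [8, 25, 2]:

```python
import torch

# Hypothetical stand-in with the same layout as the network output.
torch.manual_seed(0)
output = torch.randn(8, 25, 2)
index = torch.tensor([24, 10, 3, 3, 1, 1, 1, 0])

# Build an index of shape [8, 1, 2]: one position per batch element,
# repeated for both values in the last dimension.
gather_idx = index.view(-1, 1, 1).expand(-1, 1, output.size(2))

# Along dim 1: result[b, 0, f] = output[b, gather_idx[b, 0, f], f]
picked = torch.gather(output, 1, gather_idx).squeeze(1)

print(picked.shape)  # torch.Size([8, 2])
```

The `expand` call avoids copying data; it only makes the single position visible for both feature columns.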
And trying
output[:, index]
just gets me the values at all of the indices for each batch element.
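The behaviour described above follows from PyTorch's advanced indexing rules: a slice `:` in dim 0 combines every batch element with every entry of `index`, giving shape [8, 8, 2]. Supplying a second index tensor for dim 0 pairs the two tensors elementwise instead. A minimal sketch, assuming a random stand-in of shape [8, 25, 2]; `batch_idx` is a name introduced here for illustration:

```python
import torch

# Hypothetical stand-in for the network output (random values).
torch.manual_seed(0)
output = torch.randn(8, 25, 2)
index = torch.tensor([24, 10, 3, 3, 1, 1, 1, 0])

# Slice in dim 0: every batch element is combined with every index -> [8, 8, 2].
all_pairs = output[:, index]

# Two index tensors of the same length are paired elementwise:
# row b of the result is output[b, index[b]] -> [8, 2].
batch_idx = torch.arange(output.size(0))
picked = output[batch_idx, index]

print(all_pairs.shape)  # torch.Size([8, 8, 2])
print(picked.shape)     # torch.Size([8, 2])
```

Compared with `gather`, this form needs no reshaping of `index`, only a matching tensor of batch positions.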
What is the correct syntax to get these values?