Question about torch.gather with 3+ dimensions

Hi,

I have 16 threads, and for every thread an 8x5 matrix. For every thread, only one of the 8 rows (each a length-5 vector) interests me; the rest can be discarded.

I’ve tried to do:

action_prob.gather(1, self.current_options.view(-1, 1))

Where action_prob is a 16x8x5 tensor and current_options is a 16x1 tensor. I’ll attach their values at the bottom.

I’m getting this error:
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.c:581

It seems to me entirely clear what I’m trying to achieve and I don’t understand why there is even a dimension issue. Assuming for example current_options[0] has a value of 2, and as can be seen below, the first matrix is:

[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]]

Why won’t it gather this line: [ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108] ? And do the same for the rest of the 16 matrices, so that in the end I’ll have a 16x5 tensor.

Thanks for the help.

For what it’s worth:

action_prob:

tensor([[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]]])

self.current_options:

tensor([[ 5],
[ 0],
[ 2],
[ 6],
[ 3],
[ 4],
[ 1],
[ 5],
[ 6],
[ 5],
[ 2],
[ 6],
[ 7],
[ 7],
[ 7],
[ 1]])

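The docs for gather describe how its indexing works: https://pytorch.org/docs/stable/torch.html#torch.gather. The error comes from the fact that gather needs the index tensor to have the same number of dimensions as the input: action_prob is 3-D, but current_options.view(-1, 1) is only 2-D.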

In particular, it sounds like you want something like:

out[i][j][k] = input[i][indices[i]][k]

We can massage this a little into the form that torch.gather wants, which is out[i][j][k] = input[i][index[i][j][k]][k].
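To make that rule concrete, here is a minimal sketch that spells out the dim=1 rule with explicit loops and compares it with torch.gather on a toy tensor. The helper name gather_dim1_reference and the toy shapes are made up for illustration; they are not part of PyTorch or of the original post.

```python
import torch


def gather_dim1_reference(inp, index):
    """Loop version of inp.gather(1, index) for 3-D tensors (illustrative helper)."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            for k in range(index.size(2)):
                # The rule from the docs for dim=1:
                # out[i][j][k] = input[i][index[i][j][k]][k]
                out[i, j, k] = inp[i, index[i, j, k], k]
    return out


# Toy stand-in for action_prob: 2 "threads", 4 rows, 3 columns.
inp = torch.arange(24, dtype=torch.float32).view(2, 4, 3)
# One row index per thread, repeated across the last dimension: shape (2, 1, 3).
index = torch.tensor([[[1, 1, 1]], [[3, 3, 3]]])

ref = gather_dim1_reference(inp, index)
builtin = inp.gather(1, index)
print(builtin.shape)  # the output has the same shape as index
```

Note that the index has to be 3-D here, and that the output takes its shape from the index rather than from the input.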

What happens if you do the following:

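indices = self.current_options.view(-1, 1, 1)  # (16, 1) -> (16, 1, 1)
# Repeat each index along the last dimension so it matches action_prob there: (16, 1, 5)
indices = indices.expand(-1, 1, action_prob.size(2))
# gather returns a tensor shaped like indices, (16, 1, 5); squeeze(1) turns it into (16, 5)
action_prob.gather(1, indices).squeeze(1)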
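For reference, a self-contained sketch of the gather approach that ends with the 16x5 result asked for in the question. The random data, and the assumption that current_options is an int64 tensor of shape (16, 1), are mine and only stand in for the real values.

```python
import torch

torch.manual_seed(0)
# Made-up stand-ins for the tensors in the question.
action_prob = torch.softmax(torch.randn(16, 8, 5), dim=-1)  # (16, 8, 5)
current_options = torch.randint(0, 8, (16, 1))              # (16, 1), int64

# Give the index three dimensions, then repeat it across the last one.
index = current_options.view(-1, 1, 1).expand(-1, 1, action_prob.size(2))  # (16, 1, 5)

# Pick one row per thread along dim 1 and drop the singleton dimension.
selected = action_prob.gather(1, index).squeeze(1)  # (16, 5)

# Cross-check against a plain Python loop.
expected = torch.stack(
    [action_prob[i, int(current_options[i, 0])] for i in range(action_prob.size(0))]
)
print(selected.shape, torch.equal(selected, expected))
```

Expanding the index only to (16, 1, 5) rather than to the full (16, 8, 5) keeps the output at one row per thread, so a single squeeze is enough.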

An easier alternative to gather is to use Python (advanced) indexing.

Let’s say you have your input (size (16, 8, 5)), and indices of size (16, 1) that contain a number in the range [0, 8). Then one thing you can do is:

count = torch.arange(16)  # get numbers from 0..15
indices = indices.squeeze()  # indices is now size `(16,)`
input[count, indices, :]  # gives an output of size (16, 5).
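A runnable sketch of the indexing approach, under the same assumptions as the snippet above (input of size (16, 8, 5), int64 indices of size (16, 1)). The names probs and options and the random data are placeholders of mine; I also avoid the name input, since it shadows the Python builtin.

```python
import torch

torch.manual_seed(0)
probs = torch.rand(16, 8, 5)            # placeholder for the (16, 8, 5) input
options = torch.randint(0, 8, (16, 1))  # placeholder for the (16, 1) indices

rows = torch.arange(probs.size(0))  # 0..15, one entry per thread
cols = options.squeeze(-1)          # (16, 1) -> (16,)

# Pair rows[i] with cols[i]; the trailing dimension is kept whole.
picked = probs[rows, cols]          # (16, 5)

# Same selection written as an explicit loop, for comparison.
looped = torch.stack([probs[i, int(options[i, 0])] for i in range(probs.size(0))])
print(picked.shape, torch.equal(picked, looped))
```

If the tensors live on a GPU, the arange tensor would presumably need to be created on the same device, e.g. torch.arange(n, device=probs.device).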

Thanks Richard, with a few adjustments your solution worked.