Hi,
I have 16 threads, and each thread has an 8x5 matrix. For each thread I only care about one of the eight rows (each a vector of length 5); the rest can be discarded.
I tried:
action_prob.gather(1, self.current_options.view(-1, 1))
Here action_prob is a 16x8x5 tensor and current_options is a 16x1 tensor. Their values are listed at the bottom.
I’m getting this error:
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.c:581
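The error comes from a rule of torch.gather: the index tensor must have the same number of dimensions as the input. Here the input is 3-D (16x8x5) and the index is 2-D (16x1). The sketch below reproduces the failure and then builds a 3-D index that repeats each thread's row number across all 5 columns. It assumes random stand-in tensors with the same shapes as in the post, and it only checks the exception type because the message text varies between PyTorch versions.

```python
import torch

# Hypothetical stand-ins with the shapes from the post (values are random).
action_prob = torch.rand(16, 8, 5)
current_options = torch.randint(0, 8, (16, 1))

# A 2-D index on a 3-D input reproduces the failure.
rejected = False
try:
    action_prob.gather(1, current_options.view(-1, 1))
except RuntimeError as err:
    rejected = True
    print("2-D index rejected:", err)

# For dim=1, gather computes out[i][j][k] = action_prob[i][index[i][j][k]][k],
# so the index must be 3-D and repeat the row number across all 5 columns.
index = current_options.view(-1, 1, 1).expand(-1, 1, action_prob.size(2))  # (16, 1, 5)
selected = action_prob.gather(1, index).squeeze(1)  # (16, 1, 5) -> (16, 5)
print(selected.shape)  # torch.Size([16, 5])
```

The expand call does not copy data. It only creates a view, which gather accepts as an index.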
What I'm trying to achieve seems clear to me, and I don't understand why there is a dimension issue at all. Suppose, for example, that current_options[0] is 2. The first matrix (also shown below) is:
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]]
Why won't gather select this row, [ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108], and do the same for each of the other 15 matrices, so that I end up with a 16x5 tensor?
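As an alternative to gather, advanced indexing can express "row current_options[t] of matrix t" directly by pairing a range of thread indices with the flattened row indices. The sketch below uses the first matrix from the post. All 16 printed matrices are identical, so it is simply repeated. It also assumes, hypothetically, that row 2 is chosen for every thread.

```python
import torch

# The first 8x5 matrix from the post.
first = torch.tensor([
    [0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
    [0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
    [0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
    [0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
    [0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
    [0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
    [0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
    [0.1655, 0.1894, 0.3634, 0.1587, 0.1230],
])
action_prob = first.repeat(16, 1, 1)  # (16, 8, 5), same matrix for every thread

# Hypothetical choice: row 2 for every thread, shaped (16, 1) like current_options.
current_options = torch.full((16, 1), 2, dtype=torch.long)

# Pair thread t with row current_options[t]; the result keeps one 5-vector per thread.
threads = torch.arange(action_prob.size(0))
selected = action_prob[threads, current_options.view(-1)]  # (16, 5)

# Each row of selected equals row 2 of the first matrix: 0.1290, 0.2222, 0.1107, 0.3273, 0.2108
print(selected.shape)  # torch.Size([16, 5])
```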
Thanks for the help.
For what it’s worth:
action_prob:
tensor([[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]],
[[ 0.1783, 0.2808, 0.2792, 0.1433, 0.1184],
[ 0.1529, 0.1557, 0.1721, 0.1805, 0.3387],
[ 0.1290, 0.2222, 0.1107, 0.3273, 0.2108],
[ 0.2075, 0.2130, 0.1273, 0.1836, 0.2686],
[ 0.1478, 0.2804, 0.1561, 0.1057, 0.3100],
[ 0.2466, 0.1288, 0.1199, 0.3007, 0.2041],
[ 0.1519, 0.2347, 0.1232, 0.3820, 0.1081],
[ 0.1655, 0.1894, 0.3634, 0.1587, 0.1230]]])
self.current_options:
tensor([[ 5],
[ 0],
[ 2],
[ 6],
[ 3],
[ 4],
[ 1],
[ 5],
[ 6],
[ 5],
[ 2],
[ 6],
[ 7],
[ 7],
[ 7],
[ 1]])
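To check both approaches against the posted current_options values, the sketch below compares them with an explicit per-thread loop. It assumes a seeded random stand-in for action_prob instead of the printed one. Because the printed tensor repeats a single matrix 16 times, it would hide any mix-up between threads.

```python
import torch

torch.manual_seed(0)
# Stand-in for action_prob (random, so each thread's matrix is distinct).
action_prob = torch.rand(16, 8, 5)
# current_options exactly as posted, shape (16, 1).
current_options = torch.tensor([[5], [0], [2], [6], [3], [4], [1], [5],
                                [6], [5], [2], [6], [7], [7], [7], [1]])

idx = current_options.view(-1)  # flatten to shape (16,)

# gather with a 3-D index repeated across the last dimension.
via_gather = action_prob.gather(1, idx.view(-1, 1, 1).expand(-1, 1, 5)).squeeze(1)
# Advanced indexing with (thread, row) pairs.
via_index = action_prob[torch.arange(16), idx]
# Reference: explicit loop over threads.
via_loop = torch.stack([action_prob[t, idx[t]] for t in range(16)])

assert torch.equal(via_gather, via_loop)
assert torch.equal(via_index, via_loop)
print(via_gather.shape)  # torch.Size([16, 5])
```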