In NumPy, it is easy to index a 2-D array with a 2-D index array:
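import numpy as np

# Two columns, 0..9 and 10..19, giving a 10x2 array after the transpose
data = np.array([[0,1,2,3,4,5,6,7,8,9],[10,11,12,13,14,15,16,17,18,19]]).T
print('data')
print(data)
print()
# Each row of index lists the rows of data that form one group
index = np.array([[0,1],[3,4],[7,9]])
print('index')
print(index)
print()
print('selected data')
print(data[index])  # shape (3, 2, 2)
print()
print('summed data')
print(data[index].sum(axis=1))  # sum within each group, shape (3, 2)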
The logic is this: data is a 2-D tensor. For each row of index, I want to select multiple rows from data and then sum each group of selected rows, so that every row of index produces one row of the final matrix.
Below is the output:
data
[[ 0 10]
 [ 1 11]
 [ 2 12]
 [ 3 13]
 [ 4 14]
 [ 5 15]
 [ 6 16]
 [ 7 17]
 [ 8 18]
 [ 9 19]]

index
[[0 1]
 [3 4]
 [7 9]]

selected data
[[[ 0 10]
  [ 1 11]]

 [[ 3 13]
  [ 4 14]]

 [[ 7 17]
  [ 9 19]]]

summed data
[[ 1 21]
 [ 7 27]
 [16 36]]
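To make the logic described above concrete, here is a small sketch that checks the shapes involved and compares the one-line indexing against an explicit per-group loop. The names selected, summed and summed_loop are only illustrative and do not come from my real code.

```python
import numpy as np

# Same data as above: row i is [i, 10 + i]
data = np.arange(20).reshape(2, 10).T
index = np.array([[0, 1], [3, 4], [7, 9]])

# Indexing with a (3, 2) integer array picks 2 rows of data per group
selected = data[index]          # shape (3, 2, 2)
summed = selected.sum(axis=1)   # shape (3, 2): one summed row per group

# The same result written as an explicit loop over the rows of index
summed_loop = np.stack([data[row].sum(axis=0) for row in index])
assert np.array_equal(summed, summed_loop)

print(selected.shape, summed.shape)  # (3, 2, 2) (3, 2)
```

So axis=1 is the "rows within a group" axis, which is why summing over it leaves one row per row of index.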
Can anyone help me implement the same idea in PyTorch in a faster way? Right now I simply use a loop plus index_select, which is too slow. I'd appreciate any advice.
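For reference, below is a minimal sketch of the loop approach next to two loop-free candidates. It assumes data is a 2-D tensor and index is an integer (int64) tensor with the same number of entries in every row. The function names summed_loop, summed_indexed and summed_flat are hypothetical, and the loop is only my rough reconstruction of "a loop plus index_select", not the exact code I run. As far as I can tell, PyTorch tensors accept the same integer-tensor indexing as NumPy arrays, so the NumPy one-liner may carry over almost unchanged.

```python
import torch

# Same data as the NumPy example: row i is [i, 10 + i]
data = torch.arange(20).reshape(2, 10).t()
index = torch.tensor([[0, 1], [3, 4], [7, 9]])  # int64 by default


def summed_loop(data, index):
    # Baseline: one index_select call per row of index (slow for many rows)
    return torch.stack([data.index_select(0, row).sum(dim=0) for row in index])


def summed_indexed(data, index):
    # Index with the whole 2-D tensor at once, then sum within each group
    return data[index].sum(dim=1)


def summed_flat(data, index):
    # Single index_select on the flattened index, then regroup and sum
    flat = data.index_select(0, index.reshape(-1))
    return flat.reshape(index.shape[0], index.shape[1], -1).sum(dim=1)


print(summed_indexed(data, index))
```

All three should produce the same 3x2 result as the NumPy version. Whether the loop-free variants are actually faster on a given tensor size and device is something I have not measured here.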