What is the best practice for pairwise indexing?
More importantly, what is the fastest?
I tried the advanced indexing that works for NumPy arrays, but the same approach does not work for PyTorch tensors.
import numpy as np
import torch
y = np.arange(16)
y = y.reshape(2,2,2,2)
# Naive way. This is what I want to do, but for larger dimensions, of course.
z = []
for idx in range(2):
    z.append(y[idx,:,idx,:])
z = np.stack(z)
print(z)
print(z.shape)
# Simultaneously indexing both axes works for NumPy
z = y[[0,1],:,[0,1],:]
print(z)
print(z.shape)
# Now with np.arange, which is easier when the dimensions are larger
z = y[np.arange(2),:,np.arange(2),:]
print(z)
print(z.shape)
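A minimal NumPy sketch of why the two index arrays behave "pairwise": they are broadcast together, so `y[idx, :, idx, :]` selects the pairs (0, 0) and (1, 1) rather than every combination. Because the two advanced indices are separated by a slice, NumPy places the broadcast index dimension first in the result. The names `paired`, `all_pairs` and `diag` are just for illustration; the `np.diagonal` line is an equivalent formulation I am adding, not something from the question.

```python
import numpy as np

# Same toy array as in the question: y[a, b, c, d] == 8*a + 4*b + 2*c + d
y = np.arange(16).reshape(2, 2, 2, 2)
idx = np.arange(2)

# Paired indexing: idx and idx are broadcast elementwise -> pairs (0,0), (1,1).
# The advanced indices are separated by a slice, so the index dim goes first.
paired = y[idx, :, idx, :]
print(paired.shape)  # (2, 2, 2)

# For contrast: broadcasting a column against a row gives ALL combinations
# (i, j), i.e. all_pairs[i, j] == y[i, :, j, :]. Shape is (2, 2, 2, 2).
all_pairs = y[idx[:, None], :, idx[None, :], :]
print(all_pairs.shape)  # (2, 2, 2, 2)

# Pairing equal indices on axes 0 and 2 is a diagonal over those two axes.
# np.diagonal appends the diagonal axis last, so move it back to the front.
diag = np.moveaxis(np.diagonal(y, axis1=0, axis2=2), -1, 0)
print(np.array_equal(paired, diag))  # True
```

The diagonal form only applies when both axes use the same index array; for arbitrary pairs of index arrays, the broadcast indexing form is the general one.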
# Now in PyTorch
ty = torch.from_numpy(y)
ar = torch.arange(0,2).type(torch.LongTensor)
tz = ty[ar,:,ar,:] # this fails here
print(tz)
print(tz.size())
This is the output I get after running this small script:
[[[ 0  1]
  [ 4  5]]

 [[10 11]
  [14 15]]]
(2, 2, 2)
[[[ 0  1]
  [ 4  5]]

 [[10 11]
  [14 15]]]
(2, 2, 2)
[[[ 0  1]
  [ 4  5]]

 [[10 11]
  [14 15]]]
(2, 2, 2)
Traceback (most recent call last):
  File "test_pytorch_indexing.py", line 32, in <module>
    tz = ty[ar,:,ar,:]
TypeError: indexing a tensor with an object of type torch.LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.
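The error message says that this PyTorch version accepts a LongTensor only as the *single* index argument. One possible workaround under that restriction, sketched here assuming the two paired axes have the same length `n`: move the paired axes next to each other, merge them into one axis of length `n * n`, and select the linear positions `i * (n + 1)` with one LongTensor. It sticks to `permute`, `contiguous().view()` and `index_select`, which I believe also exist in old releases; recent PyTorch releases accept `ty[ar, :, ar, :]` directly with NumPy semantics.

```python
import numpy as np
import torch

y = np.arange(16).reshape(2, 2, 2, 2)
ty = torch.from_numpy(y)
n = ty.size(0)  # assumption: ty.size(0) == ty.size(2), as in the question

# Bring the paired axes (0 and 2) next to each other: t[a, c, b, d] == ty[a, b, c, d]
t = ty.permute(0, 2, 1, 3).contiguous()

# Merge axes a and c into a single axis indexed by a * n + c
t = t.view(n * n, ty.size(1), ty.size(3))

# The pair (i, i) sits at linear position i * n + i == i * (n + 1)
ar = torch.arange(0, n).long()
lin = ar * (n + 1)

# One LongTensor along one dimension: the case the old error message allows
tz = t.index_select(0, lin)
print(tz.size())  # torch.Size([2, 2, 2])
```

The `permute(...).contiguous()` step copies the whole tensor, so this is a compatibility workaround rather than a speed optimization.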
What is the fastest way in PyTorch to do this kind of indexing?
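Which approach is fastest depends on hardware, sizes and PyTorch version, so measuring is safer than guessing. The sketch below compares three candidates on a hypothetical larger tensor (the sizes `n` and `m` are arbitrary): advanced indexing, `torch.diagonal` over axes 0 and 2 (a view, assuming a PyTorch version whose `torch.diagonal` takes `dim1`/`dim2`), and the naive loop from the question. It checks that all three agree before timing them; no timings are claimed here.

```python
import timeit
import torch

# Hypothetical sizes chosen only to make timing differences measurable
n, m = 128, 16
ty = torch.rand(n, m, n, m)
ar = torch.arange(n)

def by_advanced_indexing(t):
    # Gathers the selected elements into a new tensor of shape (n, m, m)
    return t[ar, :, ar, :]

def by_diagonal(t):
    # torch.diagonal returns a view with the diagonal axis last: (m, m, n);
    # permute moves it to the front to match the indexing result: (n, m, m)
    return torch.diagonal(t, dim1=0, dim2=2).permute(2, 0, 1)

def by_loop(t):
    # The naive approach from the question, for reference
    return torch.stack([t[i, :, i, :] for i in range(t.size(0))])

# All three should produce identical values
assert torch.equal(by_advanced_indexing(ty), by_diagonal(ty))
assert torch.equal(by_loop(ty), by_diagonal(ty))

reps = 100
for f in (by_advanced_indexing, by_diagonal, by_loop):
    secs = timeit.timeit(lambda: f(ty), number=reps)
    print(f"{f.__name__}: {secs / reps * 1e6:.1f} us per call")
```

The diagonal version does no copying at all, which makes the comparison somewhat unfair: if the result must be contiguous, add `.contiguous()` to `by_diagonal` so it pays for the copy too. On a GPU, call `torch.cuda.synchronize()` inside the timed function, since kernels launch asynchronously.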