Say I want to index a tensor along an axis, but I don’t know a priori which axis. It could be:
tensor[:, idx, :]
Or it could be:
tensor[idx, :, :]
Or:
tensor[:, :, idx]
In these examples, the tensor has 3 possible axes. Is there a way to do this kind of indexing where both the axis and the index are variable?
You can use tensor.index_select, as seen here:
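import torch

x = torch.randn(10, 10, 10)
idx = torch.randint(0, 10, (2,))
# reference results using ordinary indexing on a fixed axis
ref1 = x[idx]
ref2 = x[:, idx]
ref3 = x[:, :, idx]
# same results with index_select, where the axis is just an argument
out1 = x.index_select(0, idx)
out2 = x.index_select(1, idx)
out3 = x.index_select(2, idx)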
print((ref1 == out1).all())
# tensor(True)
print((ref2 == out2).all())
# tensor(True)
print((ref3 == out3).all())
# tensor(True)
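Note that index_select expects a 1-D integer tensor as the index and always keeps the indexed axis. If idx is a plain Python int, as in tensor[:, idx, :], the axis is dropped, and tensor.select(dim, index) is the direct equivalent. The sketch below shows that case, along with an alternative that builds the indexing tuple programmatically. The helper name index_along is hypothetical and not part of PyTorch; it just places idx at position dim with full slices before it.

```python
import torch

def index_along(x: torch.Tensor, dim: int, idx):
    # Hypothetical helper (not a PyTorch API): builds the key
    # (slice(None), ..., idx) so that x[key] == x[:, ..., idx].
    dim = dim % x.dim()  # allow negative dims like -1
    key = (slice(None),) * dim + (idx,)
    return x[key]

x = torch.randn(10, 10, 10)

# Integer index: the axis is removed, same as x[:, 3, :]
a = x.select(1, 3)
b = index_along(x, 1, 3)
print(a.shape)            # torch.Size([10, 10])
print(torch.equal(a, b))  # True

# Tensor index: the axis is kept, same as index_select
idx = torch.tensor([1, 4])
c = index_along(x, 2, idx)
print(c.shape)                                   # torch.Size([10, 10, 2])
print(torch.equal(c, x.index_select(2, idx)))    # True
```

index_select and select are the more explicit choices; the tuple-building approach is handy when you also want to mix in other index types (slices, booleans) at a variable position.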