Indexing variable axis

Say I want to index a tensor along an axis, but I don’t know a priori which axis. It could be:

tensor[:, idx, :]

Or it could be:

tensor[idx, :, :]

or:

tensor[:, :, idx]

In these examples the tensor has 3 dimensions, so there are 3 possible axes to index along. Is there a way to do this kind of indexing when both the axis and the index are variable?

You can use tensor.index_select(dim, index), which takes the axis as its first argument, as seen here:

import torch

x = torch.randn(10, 10, 10)
idx = torch.randint(0, 10, (2,))

ref1 = x[idx]
ref2 = x[:, idx]
ref3 = x[:, :, idx]

out1 = x.index_select(0, idx)
out2 = x.index_select(1, idx)
out3 = x.index_select(2, idx)

print((ref1 == out1).all())
# tensor(True)
print((ref2 == out2).all())
# tensor(True)
print((ref3 == out3).all())
# tensor(True)
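index_select expects a 1-D index tensor and always keeps the selected axis. If idx is a plain int and you want the axis dropped, as tensor[:, idx, :] does, tensor.select(dim, idx) does that instead. Another option is to build the index tuple yourself: tensor[:, idx, :] is the same as tensor[(slice(None), idx)], so you can put idx at any position. The sketch below wraps that in a small helper; the name index_along is hypothetical, not a PyTorch API.

```python
import torch

def index_along(x: torch.Tensor, dim: int, idx):
    """Hypothetical helper: x[:, ..., idx] with `idx` placed at axis `dim`.

    Builds the same index tuple Python creates for x[:, idx, :] etc.
    An int idx drops the axis; a 1-D index tensor keeps it.
    """
    dim = dim % x.dim()  # allow negative axes such as -1
    key = (slice(None),) * dim + (idx,)  # trailing axes are implicitly ":"
    return x[key]

x = torch.randn(10, 10, 10)
idx = torch.randint(0, 10, (2,))

# Tensor index: matches index_select on every axis (axis is kept).
for dim in range(x.dim()):
    assert torch.equal(index_along(x, dim, idx), x.index_select(dim, idx))

# Int index: the axis is removed, like x[:, 3] or x.select(1, 3).
print(index_along(x, 1, 3).shape)  # torch.Size([10, 10])
assert torch.equal(index_along(x, 1, 3), x.select(1, 3))
```

Because the helper just forwards a tuple to regular indexing, it behaves exactly like writing the slices by hand, including for index types other than the two shown here.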