The following indexing (here with n = 2):
import torch
n = 2
t = torch.rand(50, 2, 20, 20)  # in general, t is supposed to have n + 2 dimensions
ids = torch.meshgrid(*[torch.arange(10) for _ in range(n)], indexing="ij")
print(t[:, :, ids])
raises this error:
TypeError: only integer tensors of a single element can be converted to an index
I am looking for a solution that works for a generic n. The same indexing can instead be performed with NumPy.
Is there a way to do it in PyTorch?
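A minimal sketch of one possible approach, assuming the intended result is what you would get by writing t[:, :, ids[0], ids[1], ..., ids[n-1]] by hand, i.e. one meshgrid tensor per trailing dimension. Instead of nesting the tuple ids inside the index, the index tuple is built explicitly with slice(None) (the object behind ':') for the two leading dimensions and the meshgrid tensors unpacked after them. The helper name index_trailing is hypothetical, chosen only for illustration; indexing="ij" assumes PyTorch 1.10 or newer.

```python
import torch


def index_trailing(t: torch.Tensor, ids) -> torch.Tensor:
    """Hypothetical helper: keep the first two dims of t whole and
    index each remaining dim with one tensor from ids."""
    # slice(None) is what ':' means inside t[...]; unpacking ids places
    # one index tensor per trailing dimension, instead of one nested tuple.
    return t[(slice(None), slice(None), *ids)]


n = 3
t = torch.rand(50, 2, *([20] * n))  # n + 2 dimensions
# indexing="ij" matches meshgrid's legacy default and silences the warning
ids = torch.meshgrid(*[torch.arange(10) for _ in range(n)], indexing="ij")
out = index_trailing(t, ids)
print(out.shape)  # torch.Size([50, 2, 10, 10, 10])
```

Because all the index tensors are adjacent in the index tuple, their broadcast shape (10, ..., 10) replaces the trailing dimensions in place, so the result keeps the leading (50, 2). On Python 3.11 or newer, the star expression can also be written directly in the subscript as t[:, :, *ids].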