# Efficient implementation using smart indexing

Hello, everyone!
Is it possible to implement the following function in a vectorized fashion, without for loops, using advanced indexing?

```python
import torch

def dumb_foo(x, permutation):
    assert x.ndimension() == permutation.ndimension()
    ret = torch.zeros_like(x)
    if x.ndimension() == 1:
        ret = x[permutation]
    elif x.ndimension() == 2:
        for i in range(x.size(0)):
            # Permute row i with its own index vector permutation[i]
            ret[i] = x[i, permutation[i]]
    elif x.ndimension() == 3:
        for i in range(x.size(0)):
            for j in range(x.size(1)):
                ret[i, j] = x[i, j, permutation[i, j]]
    else:
        raise ValueError("Only 3 dimensions maximum")
    return ret
```
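Written out, the intended operation (the one `smart_foo` below also implements) is a gather along the last dimension. For an $n$-dimensional `x` and an index tensor $p$ = `permutation` of the same shape:

$$
\text{ret}[i_1, \dots, i_{n-1}, k] = x\big[i_1, \dots, i_{n-1},\; p[i_1, \dots, i_{n-1}, k]\big]
$$

For $n = 2$ this reads $\text{ret}[i, k] = x[i, p[i, k]]$, and for $n = 1$ it reduces to plain indexing, $\text{ret}[k] = x[p[k]]$.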

So far I have ended up with something like the following. Can it be implemented more efficiently?

```python
import torch

def smart_foo(x, permutation):
    assert x.ndimension() == permutation.ndimension()
    if x.ndimension() == 1:
        ret = x[permutation]
    elif x.ndimension() == 2:
        d1, d2 = x.size()
        # Row index i repeated d2 times, paired with permutation[i, j]
        ret = x[
            torch.arange(d1).unsqueeze(1).repeat((1, d2)).flatten(),
            permutation.flatten()
        ].view(d1, d2)
    elif x.ndimension() == 3:
        d1, d2, d3 = x.size()
        # Dim-0 index: each i repeated d2 * d3 times
        # Dim-1 index: each j repeated d3 times, whole pattern tiled d1 times
        ret = x[
            torch.arange(d1).unsqueeze(1).repeat((1, d2 * d3)).flatten(),
            torch.arange(d2).unsqueeze(1).repeat((1, d3)).flatten().unsqueeze(0).repeat((1, d1)).flatten(),
            permutation.flatten()
        ].view(d1, d2, d3)
    else:
        raise ValueError("Only 3 dimensions maximum")
    return ret
```
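A sketch of two possibly simpler alternatives, assuming `permutation` is an int64 tensor with exactly the shape of `x` (as in both functions above). `torch.gather(x, -1, index)` computes `out[..., k] = x[..., index[..., k]]` along the last dimension, which is this operation directly. Alternatively, the index tensors used in advanced indexing broadcast against one another, so a `torch.arange` reshaped with `view` to size 1 in every other dimension can replace the `repeat(...).flatten()` / `.view(...)` bookkeeping. The names `gather_foo` and `broadcast_foo` are hypothetical, both work for any number of dimensions (not only up to 3), and I have not benchmarked them, so timing on your own shapes is the reliable way to compare.

```python
import torch


def gather_foo(x, permutation):
    # out[..., k] = x[..., permutation[..., k]]: a gather along the last dim.
    # torch.gather needs an int64 index with the same number of dims as x.
    assert x.dim() == permutation.dim()
    return torch.gather(x, -1, permutation)


def broadcast_foo(x, permutation):
    # Same result via advanced indexing, relying on broadcasting of the
    # index tensors instead of repeat(...).flatten() and view(...).
    assert x.dim() == permutation.dim() >= 1
    n = x.dim()
    leading = []
    for dim, size in enumerate(x.shape[:-1]):
        # arange over this dim, shaped e.g. (d1, 1, 1) or (1, d2, 1) for 3-D x
        shape = [1] * n
        shape[dim] = size
        leading.append(torch.arange(size, device=x.device).view(*shape))
    # All index tensors broadcast to permutation's shape
    return x[(*leading, permutation)]


if __name__ == "__main__":
    x = torch.randn(4, 5, 6)
    # A random permutation of the last dim for every (i, j) position
    permutation = torch.argsort(torch.rand(4, 5, 6), dim=-1)
    assert torch.equal(gather_foo(x, permutation), broadcast_foo(x, permutation))
```

Of the two, `gather_foo` needs no extra index tensors at all, which is why it seems the more natural fit here.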