Efficient implementation using smart indexing

Hello, everyone!
Is it possible to implement the following function in a vectorized fashion, without for loops, using advanced indexing?

def dumb_foo(x, permutation):
    """Reorder the entries of ``x`` along its last dimension (loop-based reference).

    For every position of the leading dimensions, the entries along the last
    dimension of ``x`` are picked in the order given by the matching slice of
    ``permutation``:

    * 1-D: ``ret = x[permutation]``
    * 2-D: ``ret[i] = x[i, permutation[i]]``
    * 3-D: ``ret[i, j] = x[i, j, permutation[i, j]]``

    Args:
        x: tensor with 1, 2 or 3 dimensions.
        permutation: integer index tensor with the same number of dimensions
            as ``x``; it is expected to have the same shape as ``x``.

    Returns:
        The reordered tensor. In the 2-D and 3-D cases it has the shape and
        dtype of ``x``.

    Raises:
        AssertionError: if ``x`` and ``permutation`` differ in number of
            dimensions.
        ValueError: if ``x`` does not have 1, 2 or 3 dimensions.
    """
    assert x.ndimension() == permutation.ndimension()
    if x.ndimension() == 1:
        # Plain advanced indexing; no output buffer is needed here.
        ret = x[permutation]
    elif x.ndimension() == 2:
        ret = torch.zeros_like(x)
        for i in range(x.size(0)):
            # Index row i explicitly: ``x[permutation[i]]`` would select whole
            # rows of x instead of elements inside row i.
            ret[i] = x[i, permutation[i]]
    elif x.ndimension() == 3:
        ret = torch.zeros_like(x)
        for i in range(x.size(0)):
            for j in range(x.size(1)):
                ret[i, j] = x[i, j, permutation[i, j]]
    else:
        raise ValueError("Only 3 dimensions maximum")
    return ret

So far I have ended up with something like this. Can it be implemented in a more efficient way?

def smart_foo(x, permutation):
    """Reorder the entries of ``x`` along its last dimension without Python loops.

    Computes, using advanced indexing only:

    * 1-D: ``ret = x[permutation]``
    * 2-D: ``ret[i, j] = x[i, permutation[i, j]]``
    * 3-D: ``ret[i, j, k] = x[i, j, permutation[i, j, k]]``

    Args:
        x: tensor with 1, 2 or 3 dimensions.
        permutation: integer index tensor with the same number of dimensions
            as ``x``; it is expected to have the same shape as ``x``.

    Returns:
        A new tensor holding the selected entries of ``x``.

    Raises:
        AssertionError: if ``x`` and ``permutation`` differ in number of
            dimensions.
        ValueError: if ``x`` does not have 1, 2 or 3 dimensions.
    """
    assert x.ndimension() == permutation.ndimension()
    ndim = x.ndimension()
    if ndim == 1:
        return x[permutation]
    if ndim == 2:
        # A (d1, 1) column of row numbers broadcasts against ``permutation``,
        # so no repeated/flattened index tensors have to be materialised.
        rows = torch.arange(x.size(0), device=x.device).unsqueeze(1)
        return x[rows, permutation]
    if ndim == 3:
        # Shapes (d1, 1, 1) and (1, d2, 1) broadcast with ``permutation`` to
        # address every (i, j) slice of x.
        first = torch.arange(x.size(0), device=x.device).view(-1, 1, 1)
        second = torch.arange(x.size(1), device=x.device).view(1, -1, 1)
        return x[first, second, permutation]
    raise ValueError("Only 3 dimensions maximum")