torch.compile + vmap reorders the output of a custom op wrapping torch.quantile

Hi, I’m trying to implement a vmap rule for quantile and then compile it. I don’t know how to add batching rules directly in C++, so I wrapped torch.quantile in a Python custom operator and registered a vmap rule on it. Without compiling, it works fine, but the compiled function returns a tensor with the same elements in a different order.
Torch version: 2.6.0.dev20241121+cu124

import torch
from torch import Tensor

lib = torch.library.Library('mylib', 'FRAGMENT')

@torch.library.custom_op('mylib::vquantile', mutates_args=())
def vquantile(x: Tensor, q: Tensor, dim: int = -1) -> Tensor:
    return torch.quantile(x, q, dim)

@torch.library.register_fake('mylib::vquantile')
def _(x, q, dim=-1):
    n = q.numel()
    x = x.index_select(dim, torch.zeros(n, dtype=int) ).squeeze(dim)
    return torch.empty_like(x)

@torch.library.register_vmap('mylib::vquantile')
def quantile_vmap(info, in_dims, x, q, dim=-1):
    x = vquantile(x.movedim(in_dims[0], -1), q, dim % (x.ndim - 1) )
    return x, x.ndim - 1

a = torch.arange(10.).reshape(2, 5)
q = torch.tensor([.2, .5, .8])

f1 = torch.vmap(lambda x: vquantile(x, q, -1))
f2 = torch.compile(f1)

r1 = f1(a)
r2 = f2(a)

print(f'a: \n{a}')
print(f'vmapped quantiles: \n{r1}')
print(f'compiled & vmapped quantiles: \n{r2}')

Output:

a: 
tensor([[0., 1., 2., 3., 4.],
        [5., 6., 7., 8., 9.]])
vmapped quantiles: 
tensor([[0.8000, 2.0000, 3.2000],
        [5.8000, 7.0000, 8.2000]])
compiled & vmapped quantiles: 
tensor([[0.8000, 3.2000, 7.0000],
        [2.0000, 5.8000, 8.2000]])

P.S. It seems to work correctly only if I set out_dims=-1 when calling vmap, that is:

f1 = torch.vmap(lambda x: vquantile(x, q, -1), out_dims=-1)
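
One thing that may be worth checking, given that the compiled result has the right elements in the wrong order, is whether the fake kernel reports the same memory layout as the real kernel. As far as I understand, torch.compile relies on the sizes and strides the fake kernel returns, so a mismatch could plausibly lead to the data being read in the wrong order. The sketch below is only a diagnostic, not a confirmed root cause: it calls torch.quantile directly with the shapes the vmap rule produces for the example above and prints the strides next to those of the fake output. The `.contiguous()` workaround at the end is an assumption and has not been verified against torch.compile here.

```python
import torch

# Same input the vmap rule hands to the op in the example above:
# batch dim moved to the end, reduction over dim 0.
x = torch.arange(10.).reshape(2, 5).movedim(0, -1)  # shape (5, 2)
q = torch.tensor([.2, .5, .8])

# Output of the real kernel.
real = torch.quantile(x, q, 0)

# Output of the fake kernel, copied from the registration above with dim=0.
fake = torch.empty_like(
    x.index_select(0, torch.zeros(q.numel(), dtype=int)).squeeze(0)
)

# Compare shape and strides of the two outputs.
print('real:', tuple(real.shape), real.stride())
print('fake:', tuple(fake.shape), fake.stride())

# Possible workaround (assumption): make the op return a contiguous tensor
# so that its layout matches what the fake kernel reports.
fixed = torch.quantile(x, q, 0).contiguous()
print('fixed:', tuple(fixed.shape), fixed.stride())
```

If the strides differ, returning `torch.quantile(x, q, dim).contiguous()` from `vquantile` may be enough. Note also that torch.quantile with a 1-D q puts the quantile dimension first rather than at `dim`, so the fake kernel's shape may not match the real one for inputs with more than one non-reduced dimension.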