Hi, I'm trying to implement a vmap rule for quantile and then compile it. I don't know how to add batching rules directly in C++, so I created a custom operator around torch.quantile in Python and registered a vmap rule for it. Without compiling it works fine, but the compiled function returns a tensor with the same elements in a different order.
Torch version: 2.6.0.dev20241121+cu124
```python
import torch
from torch import Tensor

lib = torch.library.Library('mylib', 'FRAGMENT')

# Eager implementation: a thin wrapper around torch.quantile
@torch.library.custom_op('mylib::vquantile', mutates_args=())
def vquantile(x: Tensor, q: Tensor, dim: int = -1) -> Tensor:
    return torch.quantile(x, q, dim)

# Fake implementation, used instead of the real one while tracing under torch.compile
@torch.library.register_fake('mylib::vquantile')
def _(x, q, dim=-1):
    n = q.numel()
    x = x.index_select(dim, torch.zeros(n, dtype=int)).squeeze(dim)
    return torch.empty_like(x)

# Batching rule: move the batch dim to the end and call the op once on the whole batch
@torch.library.register_vmap('mylib::vquantile')
def quantile_vmap(info, in_dims, x, q, dim=-1):
    x = vquantile(x.movedim(in_dims[0], -1), q, dim % (x.ndim - 1))
    return x, x.ndim - 1

a = torch.arange(10.).reshape(2, 5)
q = torch.tensor([.2, .5, .8])

f1 = torch.vmap(lambda x: vquantile(x, q, -1))
f2 = torch.compile(f1)

r1 = f1(a)
r2 = f2(a)

print(f'a: \n{a}')
print(f'vmapped quantiles: \n{r1}')
print(f'compiled & vmapped quantiles: \n{r2}')
```
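One thing worth checking is whether the fake implementation reports the same strides as the real kernel for the exact call the vmap rule makes (x of shape (5, 2) after `movedim`, reducing over dim 0). Under torch.compile the generated code relies on the fake output's metadata, so a stride mismatch could plausibly reorder elements without changing their values. A minimal check, running the post's fake logic as a plain function on real tensors (`fake_vquantile` is a hypothetical name, not a library API):

```python
import torch


def fake_vquantile(x, q, dim=-1):
    # Same body as the register_fake function in the post, called on real tensors
    n = q.numel()
    x = x.index_select(dim, torch.zeros(n, dtype=int)).squeeze(dim)
    return torch.empty_like(x)


a = torch.arange(10.).reshape(2, 5)
q = torch.tensor([.2, .5, .8])

# What the vmap rule passes down: batch dim moved to the end, reduce over dim 0
x = a.movedim(0, -1)

real = torch.quantile(x, q, 0)
fake = fake_vquantile(x, q, 0)

# Sizes agree, but the strides (and therefore the memory layout) may not
print("real:", real.shape, real.stride(), real.is_contiguous())
print("fake:", fake.shape, fake.stride(), fake.is_contiguous())
```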
Output:
```
a:
tensor([[0., 1., 2., 3., 4.],
        [5., 6., 7., 8., 9.]])
vmapped quantiles:
tensor([[0.8000, 2.0000, 3.2000],
        [5.8000, 7.0000, 8.2000]])
compiled & vmapped quantiles:
tensor([[0.8000, 3.2000, 7.0000],
        [2.0000, 5.8000, 8.2000]])
```
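The compiled ordering is not arbitrary. It matches what you get if the eager quantile result's memory is read as a contiguous (3, 2) tensor and vmap then moves the batch dim (out_dim 1) to the front. A sketch that reproduces both printed results from plain `torch.quantile`, assuming the same call the vmap rule makes:

```python
import torch

a = torch.arange(10.).reshape(2, 5)
q = torch.tensor([.2, .5, .8])

# The call the vmap rule makes: batch dim last, reduce over dim 0 -> shape (3, 2)
r = torch.quantile(a.movedim(0, -1), q, 0)

# Read the same storage as if it were a contiguous (3, 2) tensor,
# which is the layout a contiguous fake output would promise
misread = r.as_strided(r.shape, (r.shape[1], 1))

# vmap moves out_dim 1 (the batch dim) to the front
print(r.movedim(1, 0))        # matches the eager vmapped result
print(misread.movedim(1, 0))  # same values, ordered like the compiled result
```

If this holds, it suggests the problem is a layout mismatch between the real kernel and the fake implementation rather than the index arithmetic in the vmap rule.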
P.S. It seems to work correctly only if I pass out_dims=-1 to vmap, i.e.:
```python
f1 = torch.vmap(lambda x: vquantile(x, q, -1), out_dims=-1)
```