Compiling vmapped custom op

Hello,
I’m trying to compile a vmapped custom operator and I get an “INTERNAL ASSERT FAILED” error. The four test functions below show that native PyTorch operators work with torch.compile and vmap combined, and that a custom operator works with either torch.compile or vmap alone, but not with both together.

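import torch
from torch import Tensor

# Functional custom op (mutates no inputs) in the "mylib" namespace.
@torch.library.custom_op('mylib::inc_custom', mutates_args=())
def inc_custom(x: Tensor) -> Tensor:
    return x + 1

# Fake kernel: lets torch.compile infer output shape/dtype without running the op.
@torch.library.register_fake('mylib::inc_custom')
def _(x):
    return torch.empty_like(x)

# vmap rule: x arrives with its batch dim at in_dims[0]; the op is elementwise,
# so the output keeps that same batch dim.
@torch.library.register_vmap('mylib::inc_custom')
def inc_vmap(info, in_dims, x):
    return inc_custom(x), in_dims[0]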

def inc_simple(x: Tensor) -> Tensor:
    return x + 1


f1 = torch.compile( torch.vmap(inc_simple) )
f2 = torch.compile(inc_custom)
f3 = torch.vmap(inc_custom)
f4 = torch.compile(f3)

a = torch.arange(8)
b = a.view(2, 4)

print('inc_simple vmap + compile:')
print( f1(b) )

print('inc_custom compile:')
print( f2(a) )

print('inc_custom vmap:')
print( f3(b) )

print('inc_custom vmap + compile:')
print( f4(b) )

Output:

inc_simple vmap + compile:
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
inc_custom compile:
tensor([1, 2, 3, 4, 5, 6, 7, 8])
inc_custom vmap:
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
inc_custom vmap + compile:
Traceback (most recent call last):
  File "/home/alex/python_tests/test2.py", line 42, in <module>
    print( f4(b) )

...
...
...

  File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function mylib.inc_custom.default(*(BatchedTensor(lvl=1, bdim=0, value=
    FakeTensor(..., size=(2, 4), dtype=torch.int64)
),), **{}):
tls_on_entry.has_value() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1729647429097/work/aten/src/ATen/core/PythonFallbackKernel.cpp":49, please report a bug to PyTorch. 

from user code:
  File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 669, in __call__
    return self._opoverload(*args, **kwargs)

Could you check your code using the latest nightly binary? If it’s still failing, could you create an issue on GitHub, please?


Thanks a lot, it works in the latest nightly version.
P.S. After updating to the latest version I ran into another strange problem. I created a torch tensor, converted it to NumPy, and called PCA.fit_transform from sklearn.decomposition on it. Even for a relatively small array this call takes a very long time. But if I create the same array directly with NumPy, without going through a torch tensor, it finishes almost instantly.

I don’t know exactly how you are profiling your code, but keep in mind that converting a tensor to NumPy (tensor.cpu().numpy()) will synchronize your code if previous tensor operations ran on the GPU. In that case, make sure to synchronize the code explicitly before starting and stopping the host timers.
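A minimal sketch of that timing pattern, using a hypothetical helper named timed. It synchronizes before starting and before stopping the host timer, so queued GPU work is not attributed to the wrong call; on a CPU-only machine the synchronize calls are simply skipped.

```python
import time
import torch

def timed(fn, *args):
    """Run fn(*args) and return (result, seconds), syncing CUDA around the timer."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish any previously queued GPU work first
    start = time.perf_counter()
    result = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure fn's own GPU work has finished
    return result, time.perf_counter() - start

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1000, 50, device=device)
y = x @ x.T  # on CUDA this matmul is launched asynchronously

# Only the device-to-host copy and NumPy conversion are measured here,
# because timed() waits for the matmul to finish before starting the clock.
arr, seconds = timed(lambda t: t.cpu().numpy(), y)
print(f"conversion took {seconds:.6f} s")
```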

I tried it both with and without the GPU, and the problem remained the same. But I’ve already solved it by computing the PCA with torch.linalg.svd.
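A rough sketch of computing PCA through torch.linalg.svd, assuming the data is an (n_samples, n_features) tensor; pca_svd is a hypothetical helper name. Like sklearn's PCA it centers each feature first and reports explained variance as S**2 / (n_samples - 1), but it does not apply sklearn's sign convention, so individual components may come out with flipped signs.

```python
import torch

def pca_svd(x: torch.Tensor, n_components: int):
    """Return (scores, components, explained_variance) for the top components."""
    # Center each feature (column) before decomposing.
    x_centered = x - x.mean(dim=0, keepdim=True)
    # Thin SVD: x_centered = U @ diag(S) @ Vh, singular values in descending order.
    U, S, Vh = torch.linalg.svd(x_centered, full_matrices=False)
    components = Vh[:n_components]                  # principal axes, one per row
    scores = U[:, :n_components] * S[:n_components]  # == x_centered @ components.T
    explained_variance = S[:n_components] ** 2 / (x.shape[0] - 1)
    return scores, components, explained_variance

torch.manual_seed(0)
data = torch.randn(200, 10, dtype=torch.float64)
scores, components, var = pca_svd(data, n_components=3)
```

Using U * S for the scores avoids a second matrix product; projecting new data instead requires centering it with the training mean and multiplying by components.T.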

I’ve now found another strange issue with compile and vmap in the latest nightly PyTorch version: Compile and vmap in custom op with quantile