Hello,
I’m trying to compile a vmapped custom operator and I get an “INTERNAL ASSERT FAILED” error. The four test functions below show that native PyTorch operators work with torch.compile and vmap combined, and that a custom operator works with either torch.compile or vmap on its own, but not with both together.
import torch, os
from torch import Tensor
# Create a 'FRAGMENT' library handle for the 'mylib' operator namespace.
# NOTE(review): `lib` is not referenced again in the visible snippet —
# confirm whether it is actually needed for this repro.
lib = torch.library.Library('mylib', 'FRAGMENT')
@torch.library.custom_op('mylib::inc_custom', mutates_args=())
def inc_custom(x: Tensor) -> Tensor:
return x + 1
@torch.library.register_fake('mylib::inc_custom')
def _(x):
    """Fake (meta) implementation of ``mylib::inc_custom``.

    Returns an uninitialised tensor created with ``torch.empty_like(x)``;
    only the output's metadata is meaningful, not its values.
    """
    return torch.empty_like(x)
@torch.library.register_vmap('mylib::inc_custom')
def inc_vmap(info, in_dims, x):
    """vmap rule for ``mylib::inc_custom``.

    Args:
        info: vmap metadata passed in by ``torch.library.register_vmap``
            (unused here).
        in_dims: one entry per operator input giving that input's batch
            dimension; ``in_dims[0]`` belongs to ``x``.
        x: the input tensor, with its batch dimension still present.

    Returns:
        A ``(output, out_dim)`` pair.
    """
    # The op is elementwise, so the batch dimension of the result sits exactly
    # where it was in the input. Hard-coding 0 here would mislabel the batch
    # dimension whenever the caller uses in_dims other than 0.
    return inc_custom(x), in_dims[0]
def inc_simple(x: Tensor) -> Tensor:
    """Return ``x + 1`` using a native PyTorch operator.

    Serves as the non-custom-op baseline for the vmap + compile comparison.
    """
    return x + 1
# Four variants under test:
#   f1 - native op, vmap then compile
#   f2 - custom op, compile only
#   f3 - custom op, vmap only
#   f4 - custom op, vmap then compile (the combination that fails)
f1 = torch.compile( torch.vmap(inc_simple) )
f2 = torch.compile(inc_custom)
f3 = torch.vmap(inc_custom)
f4 = torch.compile(f3)
# Inputs: a flat tensor 0..7 and a (2, 4) view of it for the vmapped variants.
a = torch.arange(8)
b = a.view(2, 4)
print('inc_simple vmap + compile:')
print( f1(b) )
print('inc_custom compile:')
print( f2(a) )
print('inc_custom vmap:')
print( f3(b) )
print('inc_custom vmap + compile:')
# This call raises the TorchRuntimeError shown in the output below.
print( f4(b) )
Output:
inc_simple vmap + compile:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
inc_custom compile:
tensor([1, 2, 3, 4, 5, 6, 7, 8])
inc_custom vmap:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
inc_custom vmap + compile:
Traceback (most recent call last):
File "/home/alex/python_tests/test2.py", line 42, in <module>
print( f4(b) )
...
...
...
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function mylib.inc_custom.default(*(BatchedTensor(lvl=1, bdim=0, value=
FakeTensor(..., size=(2, 4), dtype=torch.int64)
),), **{}):
tls_on_entry.has_value() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1729647429097/work/aten/src/ATen/core/PythonFallbackKernel.cpp":49, please report a bug to PyTorch.
from user code:
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
return vmap_impl(
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
return _flat_vmap(
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 669, in __call__
return self._opoverload(*args, **kwargs)