torch.compile + vmap + jacfwd raises TorchRuntimeError on regular Python code but not via eval()

I have a really odd runtime error that only appears when the expression is written as regular Python code, but not when it is evaluated through eval(). To reproduce it, you can run the following code in a Jupyter notebook cell or REPL.

import torch as th
import torch.func as fc

import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

def f(x, u):
    return x*10 + 4*u
        
dfdx = th.compile(th.vmap(fc.jacfwd(f, argnums=0)))

x = th.rand((1, 2))
u = th.rand((1, 2))
dfdx(x, u) # raises: TorchRuntimeError

When doing so, I get the following error:

File ~/path/to/venv/lib/python3.10/site-packages/torch/_dynamo/exc.py:653, in unimplemented(gb_type, context, explanation, hints, from_exc, log_warning, skip_frame)
    651         raise Unsupported(msg, gb_type, skip_frame, real_stack=past_real_stack)
    652     # noqa: GB_REGISTRY
--> 653     raise Unsupported(
    654         msg, gb_type, skip_frame, real_stack=past_real_stack
    655     ) from from_exc
    656 # noqa: GB_REGISTRY
    657 raise Unsupported(msg, gb_type, skip_frame)

TorchRuntimeError: RuntimeError when making fake tensor call
  Explanation: Dynamo failed to run FX node with fake tensors: call_function <built-in function mul>(*(GradTrackingTensor(lvl=3, value=
        BatchedTensor(lvl=1, bdim=0, value=
            FakeTensor(..., size=(1, 2))
        )
    ), 10), **{}): got RuntimeError('InferenceMode::is_enabled() && self.is_inference() INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/VariableMethodStubs.cpp":66, please report a bug to PyTorch. Expected this method to only be reached in inference mode and when all the inputs are inference tensors. You should NOT call this method directly as native::_fw_primal. Please use the dispatcher, i.e., at::_fw_primal. Please file an issue if you come across this error otherwise.')
  Hint: Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance("force_eager")`. 

  Developer debug context: 

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb4315.html

However, when I run the following code, with the expression inside eval():

import torch as th
import torch.func as fc

def f2(x, u):
    return eval("x*2 + 4*u", {'x': x, 'u':u})

x = th.rand((1, 2))
u = th.rand((1, 2))
dfdx2 = th.compile(th.vmap(fc.jacfwd(f2, argnums=0)))
dfdx2(x, u) # returns: tensor([[[2., 0.], [0., 2.]]])

The function runs as expected and returns the correct result. Additionally, the original code that raised the runtime error will also run normally and generate the correct result if run afterwards.

If anyone knows what may be causing this or how to fix it, please let me know. I am using Python 3.10.12 and PyTorch 2.12.0+cu132.

Hi!

Disclaimer: I’m not a compilation expert at all

I was able to reproduce your problem.
It seems like a known issue:

But I think your reproducer (the first part, without eval) is much simpler and would be a nice addition to those issues.

However, it’s interesting that the eval version works. My guess is that, because the expression is inside an eval, the compiler cannot trace it and is forced to run it in eager mode, which bypasses the bug.
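For reference, the hint in the error message suggests checking that the code works without torch.compile. A minimal sketch of that check, reusing f from the question (the names jac_eager and expected are mine, just for illustration):

```python
import torch as th
import torch.func as fc

def f(x, u):
    return x * 10 + 4 * u

x = th.rand((1, 2))
u = th.rand((1, 2))

# Same transform stack as in the question, but without th.compile,
# so everything runs in plain eager mode.
jac_eager = th.vmap(fc.jacfwd(f, argnums=0))(x, u)

# d/dx (10*x + 4*u) is 10 * identity for each batch element.
expected = 10 * th.eye(2).unsqueeze(0)
print(jac_eager.shape)
print(th.allclose(jac_eager, expected))
```

If this passes, the function itself is fine in eager mode, and the failure is specific to the compiled path.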

Additionally, the original code that raised the runtime error will also run normally and generate the correct result if run afterwards.

This is very strange. Maybe compiling the eval version caches a “compiled” version of the function x*10 + 4*u (I put quotes here because the compiler may simply have decided to run it eagerly), and compiling the original function then reuses this cached version.

Thanks for the feedback and link to existing issues. I hope the developers address it soon. The behavior is really odd.

As for the original code working once the eval version has run, I am not sure caching explains it: the two versions can use different expressions, like x*2 + 4*u and x*10 + 4*u, and the original still works after the eval version has run. Maybe what gets cached is the family of the expression rather than the exact expression.
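One way to probe the caching idea might be to clear Dynamo's compile caches between the two runs with torch._dynamo.reset(). The following is only a sketch and I have not verified which outcome it gives. The helper try_compiled is hypothetical, and backend="eager" is my own assumption to skip code generation, on the guess that the failure happens during Dynamo tracing regardless of backend:

```python
import torch
import torch._dynamo
import torch.func as fc

def f(x, u):
    return x * 10 + 4 * u

def f2(x, u):
    return eval("x*2 + 4*u", {"x": x, "u": u})

def try_compiled(fn, x, u):
    # Hypothetical helper: compile the vmap(jacfwd(...)) stack, call it once,
    # and return "ok" or the name of the exception that was raised.
    try:
        compiled = torch.compile(
            torch.vmap(fc.jacfwd(fn, argnums=0)),
            backend="eager",  # assumption: skip Inductor, keep Dynamo tracing
        )
        compiled(x, u)
        return "ok"
    except Exception as exc:
        return type(exc).__name__

x = torch.rand((1, 2))
u = torch.rand((1, 2))

first = try_compiled(f2, x, u)   # eval version first
torch._dynamo.reset()            # drop Dynamo's compile caches
second = try_compiled(f, x, u)   # original version afterwards

print("eval version:", first)
print("original after reset:", second)
```

If the original version still succeeds after the reset, leftover Dynamo cache entries are probably not what makes it work; if it fails again, caching looks more plausible.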

@ptrblck Do you have any insights on what may be happening?