torch.compile fft_ifft: n should be int, not SymInt

I’m trying to use a model that acts on variable-length sequences of shape (1, C, L), where L varies from batch to batch and the batch size is 1.

On PyTorch 1.x the varying size from batch to batch caused the model to slow down continually, so I wanted to see whether torch.compile would resolve the issue, but I’m hitting the exception in the topic title.

Here is a minimal example that recreates the error.

import torch
import torch.nn as nn

class FftMod(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        L = x.shape[-1]
        x_f = torch.fft.rfft(x, n=2 * L)
        return x_f

torch.compile(FftMod(), dynamic=True)(torch.randn(1, 10, 128))

which results in

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1194, in run_node(output_graph, node, args, kwargs, nnmodule)
   1193 if op == "call_function":
-> 1194     return node.target(*args, **kwargs)
   1195 elif op == "call_method":

TypeError: fft_rfft(): argument 'n' must be int, not SymInt

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1152, in get_fake_value(node, tx)
   1151     with tx.fake_mode, enable_python_dispatcher():
-> 1152         return wrap_fake_exception(
   1153             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1154         )
   1155 except Unsupported:

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:808, in wrap_fake_exception(fn)
    807 try:
--> 808     return fn()
    809 except UnsupportedFakeTensorException as e:

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1153, in get_fake_value.<locals>.<lambda>()
   1151     with tx.fake_mode, enable_python_dispatcher():
   1152         return wrap_fake_exception(
-> 1153             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1154         )
   1155 except Unsupported:

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1206, in run_node(output_graph, node, args, kwargs, nnmodule)
   1205 except Exception as e:
-> 1206     raise RuntimeError(
   1207         f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)"
   1208     ) from e
   1209 raise AssertionError(op)

RuntimeError: Failed running call_function <built-in function fft_rfft>(*(FakeTensor(FakeTensor(..., device='meta', size=(1, s0, s1)), cpu),), **{'n': 2*s1}):
fft_rfft(): argument 'n' must be int, not SymInt
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

TorchRuntimeError                         Traceback (most recent call last)
Cell In[8], line 1
----> 1 torch.compile(FftMod(), dynamic=True)(torch.randn(1, 10, 128))

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:82, in OptimizedModule.forward(self, *args, **kwargs)
     81 def forward(self, *args, **kwargs):
---> 82     return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:209, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    207 dynamic_ctx.__enter__()
    208 try:
--> 209     return fn(*args, **kwargs)
    210 finally:
    211     set_eval_frame(prior)

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:337, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
    334             return hijacked_callback(frame, cache_size, hooks)
    336 with compile_lock:
--> 337     return callback(frame, cache_size, hooks)

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:404, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
    402 counters["frames"]["total"] += 1
    403 try:
--> 404     result = inner_convert(frame, cache_size, hooks)
    405     counters["frames"]["ok"] += 1
    406     return result

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:104, in wrap_convert_context.<locals>._fn(*args, **kwargs)
    102 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    103 try:
--> 104     return fn(*args, **kwargs)
    105 finally:
    106     torch._C._set_grad_enabled(prior_grad_mode)

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:262, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
    259 global initial_grad_state
    260 initial_grad_state = torch.is_grad_enabled()
--> 262 return _compile(
    263     frame.f_code,
    264     frame.f_globals,
    265     frame.f_locals,
    266     frame.f_builtins,
    267     compiler_fn,
    268     one_graph,
    269     export,
    270     hooks,
    271     frame,
    272 )

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:163, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    161     compilation_metrics[key] = []
    162 t0 = time.time()
--> 163 r = func(*args, **kwargs)
    164 time_spent = time.time() - t0
    165 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:324, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    322 for attempt in itertools.count():
    323     try:
--> 324         out_code = transform_code_object(code, transform)
    325         orig_code_map[out_code] = code
    326         break

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:445, in transform_code_object(code, transformations, safe)
    442 instructions = cleaned_instructions(code, safe)
    443 propagate_line_nums(instructions)
--> 445 transformations(instructions, code_options)
    446 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:311, in _compile.<locals>.transform(instructions, code_options)
    298 nonlocal output
    299 tracer = InstructionTranslator(
    300     instructions,
    301     code,
   (...)
    309     mutated_closure_cell_contents,
    310 )
--> 311 tracer.run()
    312 output = tracer.output
    313 assert output is not None

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1726, in InstructionTranslator.run(self)
   1724 def run(self):
   1725     _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1726     super().run()

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:576, in InstructionTranslatorBase.run(self)
    571 try:
    572     self.output.push_tx(self)
    573     while (
    574         self.instruction_pointer is not None
    575         and not self.output.should_exit
--> 576         and self.step()
    577     ):
    578         pass
    579 except BackendCompilerFailed:

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:540, in InstructionTranslatorBase.step(self)
    538     if not hasattr(self, inst.opname):
    539         unimplemented(f"missing: {inst.opname}")
--> 540     getattr(self, inst.opname)(inst)
    542     return inst.opname != "RETURN_VALUE"
    543 except BackendCompilerFailed:

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:342, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    340 reason = None
    341 try:
--> 342     return inner_fn(self, inst)
    343 except Unsupported as excp:
    344     if self.has_backedge() and self.should_compile_partial_graph():

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1014, in InstructionTranslatorBase.CALL_FUNCTION_KW(self, inst)
   1012 kwargs = dict(zip(argnames, kwargs_list))
   1013 assert len(kwargs) == len(argnames)
-> 1014 self.call_function(fn, args, kwargs)

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:474, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    469 assert isinstance(kwargs, dict)
    470 assert all(
    471     isinstance(x, VariableTracker)
    472     for x in itertools.chain(args, kwargs.values())
    473 )
--> 474 self.push(fn.call_function(self, args, kwargs))

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py:548, in TorchVariable.call_function(self, tx, args, kwargs)
    544         from torch.fx.experimental.symbolic_shapes import sym_sqrt
    546         fn_ = sym_sqrt
--> 548 tensor_variable = wrap_fx_proxy(
    549     tx=tx,
    550     proxy=tx.output.create_proxy(
    551         "call_function",
    552         fn_,
    553         *proxy_args_kwargs(args, kwargs),
    554     ),
    555     **options,
    556 )
    558 if "out" in kwargs and not (
    559     isinstance(kwargs["out"], variables.ConstantVariable)
    560     and kwargs["out"].as_python_constant() is None
   (...)
    563     # torch.sigmoid mutate the tensors in the out field. Track such
    564     # tensors and rewrite the symbolic locals.
    565     if isinstance(tensor_variable, TupleVariable):

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:754, in wrap_fx_proxy(tx, proxy, example_value, **options)
    753 def wrap_fx_proxy(tx, proxy, example_value=None, **options):
--> 754     return wrap_fx_proxy_cls(
    755         target_cls=TensorVariable,
    756         tx=tx,
    757         proxy=proxy,
    758         example_value=example_value,
    759         **options,
    760     )

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:789, in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, ignore_subclass, **options)
    787 with preserve_rng_state():
    788     if example_value is None:
--> 789         example_value = get_fake_value(proxy.node, tx)
    791     # Handle recursive calls here
    792     elif isinstance(example_value, FakeTensor):

File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1173, in get_fake_value(node, tx)
   1169 elif isinstance(
   1170     cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
   1171 ):
   1172     unimplemented("guard on data-dependent symbolic int/float")
-> 1173 raise TorchRuntimeError() from e

TorchRuntimeError: 

from user code:
   File "/tmp/ipykernel_216319/1876313449.py", line 8, in forward
    x_f = torch.fft.rfft(x, n=2*L)

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Is there a workaround for this in 2.0.1? I know the dynamic compiler is still in early stages.
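
One fallback I’ve been considering (just a sketch, not verified on 2.0.1) is to force the rfft call to run eagerly with torch._dynamo.disable, so that Dynamo never sees the SymInt n at all; the cost is a graph break around the FFT, while the rest of the module would still be compiled:

import torch
import torch.nn as nn

class FftMod(nn.Module):

    def __init__(self):
        super().__init__()

    # Run the FFT eagerly so the traced graph never sees a SymInt `n`;
    # Dynamo inserts a graph break around this call.
    @torch._dynamo.disable
    def _rfft(self, x):
        L = x.shape[-1]  # a plain Python int here, since this runs outside tracing
        return torch.fft.rfft(x, n=2 * L)

    def forward(self, x):
        return self._rfft(x)

torch.compile(FftMod(), dynamic=True)(torch.randn(1, 10, 128))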

@ezyang might be interested in this. I tried out the code snippet on the nightlies and I’m still seeing the same issue, but I do remember some FFT-related code being checked in recently.


Hmm, actually there’s a guide out there for solving these problems; I’ll try to take a look: The dynamic shapes manual - Google Docs


Thanks for your replies. The document you linked does seem to contain the shape of the solution, but I haven’t been able to implement it. In particular, I’m not entirely sure where the changes would need to be made.

I tried changing int to SymInt in files where the pattern
fft_rfft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor
appeared, as the “How to SymInt’ify an operator schema” section seemed to indicate, including

...site-packages/torch/include/ATen/ops/fft_rfft.h,
...site-packages/torch/include/ATen/RegistrationDeclarations.h and
...site-packages/torch/include/ATen/ops/fft_rfft_ops.h

but there was no effect on the stack trace.

Since torch/fft/__init__.py seems to just be a thin wrapper around various torch._C code, I’m guessing that some changes also need to be made to the C++ definitions, but I think I’m out of my depth in the “How to SymInt’ify C++ operator” section of the guide.
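
For what it’s worth, my current reading of the guide (which I haven’t verified) is that the installed headers are generated files, and the schema change would have to be made in the PyTorch source tree in aten/src/ATen/native/native_functions.yaml, followed by rebuilding from source, along the lines of:

fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor

with the corresponding C++ kernel signature presumably updated to accept a c10::SymInt, which is exactly the part of the guide I’m unsure about.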