I’m trying to use a model that acts on variable-length sequences of shape (1, C, L), where L varies from batch to batch and the batch size is 1.
On PyTorch 1.x the varying input length resulted in a continual slowdown, so I wanted to see whether the issue might be resolved with torch.compile, but I’m getting the exception indicated by the topic title.
Here is a minimal example that reproduces the error:
import torch
import torch.nn as nn

class FftMod(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        L = x.shape[-1]
        # zero-pad to length 2*L before computing the real FFT
        x_f = torch.fft.rfft(x, n=2*L)
        return x_f

torch.compile(FftMod(), dynamic=True)(torch.randn(1, 10, 128))
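For reference, the same module runs fine in eager mode, so the failure seems specific to tracing with a symbolic L:

out = FftMod()(torch.randn(1, 10, 128))
print(out.shape)  # torch.Size([1, 10, 129]): 2*128 // 2 + 1 frequency bins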
which results in:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1194, in run_node(output_graph, node, args, kwargs, nnmodule)
1193 if op == "call_function":
-> 1194 return node.target(*args, **kwargs)
1195 elif op == "call_method":
TypeError: fft_rfft(): argument 'n' must be int, not SymInt
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1152, in get_fake_value(node, tx)
1151 with tx.fake_mode, enable_python_dispatcher():
-> 1152 return wrap_fake_exception(
1153 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
1154 )
1155 except Unsupported:
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:808, in wrap_fake_exception(fn)
807 try:
--> 808 return fn()
809 except UnsupportedFakeTensorException as e:
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1153, in get_fake_value.<locals>.<lambda>()
1151 with tx.fake_mode, enable_python_dispatcher():
1152 return wrap_fake_exception(
-> 1153 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
1154 )
1155 except Unsupported:
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1206, in run_node(output_graph, node, args, kwargs, nnmodule)
1205 except Exception as e:
-> 1206 raise RuntimeError(
1207 f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)"
1208 ) from e
1209 raise AssertionError(op)
RuntimeError: Failed running call_function <built-in function fft_rfft>(*(FakeTensor(FakeTensor(..., device='meta', size=(1, s0, s1)), cpu),), **{'n': 2*s1}):
fft_rfft(): argument 'n' must be int, not SymInt
(scroll up for backtrace)
The above exception was the direct cause of the following exception:
TorchRuntimeError Traceback (most recent call last)
Cell In[8], line 1
----> 1 torch.compile(FftMod(), dynamic=True)(torch.randn(1, 10, 128))
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:82, in OptimizedModule.forward(self, *args, **kwargs)
81 def forward(self, *args, **kwargs):
---> 82 return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:209, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
207 dynamic_ctx.__enter__()
208 try:
--> 209 return fn(*args, **kwargs)
210 finally:
211 set_eval_frame(prior)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:337, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
334 return hijacked_callback(frame, cache_size, hooks)
336 with compile_lock:
--> 337 return callback(frame, cache_size, hooks)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:404, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
402 counters["frames"]["total"] += 1
403 try:
--> 404 result = inner_convert(frame, cache_size, hooks)
405 counters["frames"]["ok"] += 1
406 return result
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:104, in wrap_convert_context.<locals>._fn(*args, **kwargs)
102 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
103 try:
--> 104 return fn(*args, **kwargs)
105 finally:
106 torch._C._set_grad_enabled(prior_grad_mode)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:262, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
259 global initial_grad_state
260 initial_grad_state = torch.is_grad_enabled()
--> 262 return _compile(
263 frame.f_code,
264 frame.f_globals,
265 frame.f_locals,
266 frame.f_builtins,
267 compiler_fn,
268 one_graph,
269 export,
270 hooks,
271 frame,
272 )
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:163, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
161 compilation_metrics[key] = []
162 t0 = time.time()
--> 163 r = func(*args, **kwargs)
164 time_spent = time.time() - t0
165 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:324, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
322 for attempt in itertools.count():
323 try:
--> 324 out_code = transform_code_object(code, transform)
325 orig_code_map[out_code] = code
326 break
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:445, in transform_code_object(code, transformations, safe)
442 instructions = cleaned_instructions(code, safe)
443 propagate_line_nums(instructions)
--> 445 transformations(instructions, code_options)
446 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:311, in _compile.<locals>.transform(instructions, code_options)
298 nonlocal output
299 tracer = InstructionTranslator(
300 instructions,
301 code,
(...)
309 mutated_closure_cell_contents,
310 )
--> 311 tracer.run()
312 output = tracer.output
313 assert output is not None
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1726, in InstructionTranslator.run(self)
1724 def run(self):
1725 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1726 super().run()
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:576, in InstructionTranslatorBase.run(self)
571 try:
572 self.output.push_tx(self)
573 while (
574 self.instruction_pointer is not None
575 and not self.output.should_exit
--> 576 and self.step()
577 ):
578 pass
579 except BackendCompilerFailed:
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:540, in InstructionTranslatorBase.step(self)
538 if not hasattr(self, inst.opname):
539 unimplemented(f"missing: {inst.opname}")
--> 540 getattr(self, inst.opname)(inst)
542 return inst.opname != "RETURN_VALUE"
543 except BackendCompilerFailed:
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:342, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
340 reason = None
341 try:
--> 342 return inner_fn(self, inst)
343 except Unsupported as excp:
344 if self.has_backedge() and self.should_compile_partial_graph():
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1014, in InstructionTranslatorBase.CALL_FUNCTION_KW(self, inst)
1012 kwargs = dict(zip(argnames, kwargs_list))
1013 assert len(kwargs) == len(argnames)
-> 1014 self.call_function(fn, args, kwargs)
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:474, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
469 assert isinstance(kwargs, dict)
470 assert all(
471 isinstance(x, VariableTracker)
472 for x in itertools.chain(args, kwargs.values())
473 )
--> 474 self.push(fn.call_function(self, args, kwargs))
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py:548, in TorchVariable.call_function(self, tx, args, kwargs)
544 from torch.fx.experimental.symbolic_shapes import sym_sqrt
546 fn_ = sym_sqrt
--> 548 tensor_variable = wrap_fx_proxy(
549 tx=tx,
550 proxy=tx.output.create_proxy(
551 "call_function",
552 fn_,
553 *proxy_args_kwargs(args, kwargs),
554 ),
555 **options,
556 )
558 if "out" in kwargs and not (
559 isinstance(kwargs["out"], variables.ConstantVariable)
560 and kwargs["out"].as_python_constant() is None
(...)
563 # torch.sigmoid mutate the tensors in the out field. Track such
564 # tensors and rewrite the symbolic locals.
565 if isinstance(tensor_variable, TupleVariable):
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:754, in wrap_fx_proxy(tx, proxy, example_value, **options)
753 def wrap_fx_proxy(tx, proxy, example_value=None, **options):
--> 754 return wrap_fx_proxy_cls(
755 target_cls=TensorVariable,
756 tx=tx,
757 proxy=proxy,
758 example_value=example_value,
759 **options,
760 )
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:789, in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, ignore_subclass, **options)
787 with preserve_rng_state():
788 if example_value is None:
--> 789 example_value = get_fake_value(proxy.node, tx)
791 # Handle recursive calls here
792 elif isinstance(example_value, FakeTensor):
File ~/.conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py:1173, in get_fake_value(node, tx)
1169 elif isinstance(
1170 cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
1171 ):
1172 unimplemented("guard on data-dependent symbolic int/float")
-> 1173 raise TorchRuntimeError() from e
TorchRuntimeError:
from user code:
File "/tmp/ipykernel_216319/1876313449.py", line 8, in forward
x_f = torch.fft.rfft(x, n=2*L)
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
Is there a workaround for this in 2.0.1? I know dynamic-shape support in the compiler is still in its early stages.
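For what it’s worth, one avenue I’ve considered is an untested sketch (I’m not sure a symbolic pad amount traces cleanly in 2.0.1 either): since rfft(x, n=2*L) just zero-pads the signal on the right to length 2*L, the same result can be obtained by padding explicitly and calling rfft with the default n, which keeps the SymInt out of the n argument:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FftModPad(nn.Module):
    # sketch of a possible workaround: pad explicitly instead of passing n
    def forward(self, x):
        L = x.shape[-1]
        x = F.pad(x, (0, L))       # right-pad last dim from L to 2*L
        return torch.fft.rfft(x)   # default n is the (padded) input length

torch.compile(FftModPad(), dynamic=True)(torch.randn(1, 10, 128))

If that still trips over the symbolic size, wrapping the FFT call in a function decorated with torch._dynamo.disable should at least force a graph break so the call runs eagerly.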