Torch 2.0 Dynamo Inductor Does Not Work for Hugging Face Transformers Text Generation Models

Torch 2.0 Dynamo Inductor works for simple encoder-only models like BERT, but not for more complex models like T5 that rely on the .generate function.
Code:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch._dynamo as torchdynamo
import torch

# Minimal reproduction: wrap a T5 model's generate() with TorchDynamo (inductor
# backend) and call it once on a single tokenized prompt.

# Set Dynamo's cache_size_limit config value to 512
# (presumably to avoid hitting the recompile cache limit — TODO confirm).
torchdynamo.config.cache_size_limit = 512
model_name = "t5-small"
# Load the pretrained seq2seq checkpoint, put it in eval mode, and move it to the GPU.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Wrap the bound generate() method with Dynamo using the "inductor" backend.
# The wrapped callable is stored under a new attribute (generate2); model.generate
# itself is not reassigned here.
model.generate2 = torchdynamo.optimize("inductor")(model.generate)
# torch._dynamo is the same module imported above under the alias torchdynamo.
# Note the ordering: verbose is switched on after the wrapper is created but
# before the wrapper is called for the first time.
torch._dynamo.config.verbose=True
# Tokenize one prompt with return_tensors="pt" and move the result to the GPU.
inputs = tokenizer("Generate taxonomy for query: dildo", return_tensors="pt").to('cuda')
# dynamo warm up: first invocation of the wrapped generate(), run under
# torch.inference_mode(). Only the input_ids entry of the tokenizer output is
# passed. This is the call that fails with InternalTorchDynamoError in the
# traceback reported below.
with torch.inference_mode():
    outputs = model.generate2(inputs=inputs["input_ids"])

Error:

AttributeError                            Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:323, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    322 try:
--> 323     out_code = transform_code_object(code, transform)
    324     orig_code_map[out_code] = code

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:341, in transform_code_object(code, transformations, safe)
    339 propagate_line_nums(instructions)
--> 341 transformations(instructions, code_options)
    343 fix_vars(instructions, code_options)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:310, in _compile.<locals>.transform(instructions, code_options)
    298 tracer = InstructionTranslator(
    299     instructions,
    300     code,
   (...)
    308     mutated_closure_cell_contents,
    309 )
--> 310 tracer.run()
    311 output = tracer.output

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1692, in InstructionTranslator.run(self)
   1691 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1692 super().run()

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:538, in InstructionTranslatorBase.run(self)
    534 self.output.push_tx(self)
    535 while (
    536     self.instruction_pointer is not None
    537     and not self.output.should_exit
--> 538     and self.step()
    539 ):
    540     pass

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:501, in InstructionTranslatorBase.step(self)
    500     unimplemented(f"missing: {inst.opname}")
--> 501 getattr(self, inst.opname)(inst)
    503 return inst.opname != "RETURN_VALUE"

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1039, in InstructionTranslatorBase.LOAD_ATTR(self, inst)
   1038 obj = self.pop()
-> 1039 result = BuiltinVariable(getattr).call_function(
   1040     self, [obj, ConstantVariable(inst.argval)], {}
   1041 )
   1042 self.push(result)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py:322, in BuiltinVariable.call_function(self, tx, args, kwargs)
    321 try:
--> 322     result = handler(tx, *args, **kwargs)
    323     if result is not None:

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py:699, in BuiltinVariable.call_getattr(self, tx, obj, name_var, default)
    698 if isinstance(obj, variables.NNModuleVariable):
--> 699     return obj.var_getattr(tx, name).add_options(options)
    700 elif isinstance(obj, variables.TensorVariable) and name == "grad":

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py:131, in NNModuleVariable.var_getattr(self, tx, name)
    130 if object_member:
--> 131     return VariableBuilder(tx, NNModuleSource(source))(subobj)
    132 else:

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:170, in VariableBuilder.__call__(self, value)
    169     return self.tx.output.side_effects[value]
--> 170 return self._wrap(value).clone(**self.options())

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:448, in VariableBuilder._wrap(self, value)
    439     return NumpyVariable(
    440         value,
    441         source=self.source,
   (...)
    446         ),
    447     )
--> 448 elif value in tensor_dunder_fns:
    449     return TorchVariable(
    450         value,
    451         source=self.source,
    452         guards=make_guards(GuardBuilder.FUNCTION_MATCH),
    453     )

File /opt/conda/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:287, in GenerationConfig.__eq__(self, other)
    286 self_dict = self.__dict__.copy()
--> 287 other_dict = other.__dict__.copy()
    288 # ignore metadata

AttributeError: 'method_descriptor' object has no attribute '__dict__'

from user code:
   File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1183, in <graph break in generate>
    if self.generation_config._from_model_config:


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


The above exception was the direct cause of the following exception:

InternalTorchDynamoError                  Traceback (most recent call last)
Cell In[33], line 3
      1 # dynamo warm up
      2 with torch.inference_mode():
----> 3     outputs = model.generate2(inputs=inputs["input_ids"])

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:211, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    209 dynamic_ctx.__enter__()
    210 try:
--> 211     return fn(*args, **kwargs)
    212 finally:
    213     set_eval_frame(prior)

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1177, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, **kwargs)
   1106 r"""
   1107 
   1108 Generates sequences of token ids for models with a language modeling head.
   (...)
   1174             - [`~generation.BeamSampleEncoderDecoderOutput`]
   1175 """
   1176 # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-> 1177 self._validate_model_class()
   1179 # priority: `generation_config` argument > `model.generation_config` (the default generation config)
   1180 if generation_config is None:
   1181     # legacy: users may modify the model configuration to control generation -- update the generation config
   1182     # model attribute accordingly, if it was created from the model config

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:332, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
    329             return hijacked_callback(frame, cache_size, hooks)
    331 with compile_lock:
--> 332     return callback(frame, cache_size, hooks)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:403, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
    401 counters["frames"]["total"] += 1
    402 try:
--> 403     result = inner_convert(frame, cache_size, hooks)
    404     counters["frames"]["ok"] += 1
    405     return result

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:103, in wrap_convert_context.<locals>._fn(*args, **kwargs)
    101 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    102 try:
--> 103     return fn(*args, **kwargs)
    104 finally:
    105     torch._C._set_grad_enabled(prior_grad_mode)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:261, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
    258 global initial_grad_state
    259 initial_grad_state = torch.is_grad_enabled()
--> 261 return _compile(
    262     frame.f_code,
    263     frame.f_globals,
    264     frame.f_locals,
    265     frame.f_builtins,
    266     compiler_fn,
    267     one_graph,
    268     export,
    269     hooks,
    270     frame,
    271 )

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:153, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    151     compilation_metrics[key] = []
    152 t0 = time.time()
--> 153 r = func(*args, **kwargs)
    154 time_spent = time.time() - t0
    155 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:393, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    391 except Exception as e:
    392     exception_handler(e, code, frame)
--> 393     raise InternalTorchDynamoError() from e

InternalTorchDynamoError: 

Environment:

Collecting environment information...
PyTorch version: 2.0.0.dev20230124
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.27

Python version: 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.14.301-224.520.amzn2.x86_64-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.6.124
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20230124
[pip3] torchaudio==2.0.0.dev20230124
[pip3] torchdata==0.6.0.dev20230124
[pip3] torchelastic==0.2.2
[pip3] torchtext==0.15.0.dev20230124
[pip3] torchvision==0.15.0.dev20230124
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.23.5 py310hd5efca6_0
[conda] numpy-base 1.23.5 py310h8e6c178_0
[conda] pytorch 2.0.0.dev20230124 py3.10_cuda11.6_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-cuda 11.6 h867d48c_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 2.0.0.dev20230124 py310_cu116 pytorch-nightly
[conda] torchdata 0.6.0.dev20230124 py310 pytorch-nightly
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torchtext 0.15.0.dev20230124 py310 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py310 pytorch-nightly
[conda] torchvision 0.15.0.dev20230124 py310_cu116 pytorch-nightly

Could you create an issue on GitHub so that we could track and fix the issue, please?

Thanks, @ptrblck! I've added it here: Torch Dynamo Internal Error · Issue #93042 · pytorch/pytorch · GitHub