PyTorch 2.0 - OSError: source code not available

I have the following custom class, which PyTorch 2.0 fails on when I run the training loop of my compiled (`torch.compile`) VAE network:

class MaskedSequential(nn.Sequential):
    """Sequential container that threads a masks tensor through mask-aware layers.

    Behaves like ``nn.Sequential``, except that ``forward`` takes two positional
    inputs ``(x, masks)``. Each child module is applied in registration order:
    children that are ``ResConvBlock`` or ``ResConvTransposeBlock`` instances
    are called as ``layer(x, masks)``; every other child is called as
    ``layer(x)``. The same ``masks`` object is passed, unchanged, to every
    mask-aware child. Only the final ``x`` is returned.

    NOTE(review): under ``torch.compile`` in PyTorch 2.0, the global lookup of
    ``ResConvBlock`` / ``ResConvTransposeBlock`` inside ``forward`` failed with
    ``OSError: source code not available`` when those classes were defined in
    ``__main__`` (e.g. a notebook cell). Moving them into an importable ``.py``
    module is a likely workaround — confirm against your PyTorch version.
    """

    def forward(self, *inputs):
        # Exactly two positional inputs are expected; anything else raises here.
        x, masks = inputs
        for layer in self._modules.values():
            wants_masks = isinstance(layer, (ResConvBlock, ResConvTransposeBlock))
            x = layer(x, masks) if wants_masks else layer(x)
        return x

I get the following error:

OSError                                   Traceback (most recent call last)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:324, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    323 try:
--> 324     out_code = transform_code_object(code, transform)
    325     orig_code_map[out_code] = code

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:445, in transform_code_object(code, transformations, safe)
    443 propagate_line_nums(instructions)
--> 445 transformations(instructions, code_options)
    446 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:311, in _compile.<locals>.transform(instructions, code_options)
    299 tracer = InstructionTranslator(
    300     instructions,
    301     code,
   (...)
    309     mutated_closure_cell_contents,
    310 )
--> 311 tracer.run()
    312 output = tracer.output

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1726, in InstructionTranslator.run(self)
   1725 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1726 super().run()

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:576, in InstructionTranslatorBase.run(self)
    572 self.output.push_tx(self)
    573 while (
    574     self.instruction_pointer is not None
    575     and not self.output.should_exit
--> 576     and self.step()
    577 ):
    578     pass

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:540, in InstructionTranslatorBase.step(self)
    539     unimplemented(f"missing: {inst.opname}")
--> 540 getattr(self, inst.opname)(inst)
    542 return inst.opname != "RETURN_VALUE"

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:342, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    341 try:
--> 342     return inner_fn(self, inst)
    343 except Unsupported as excp:

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:965, in InstructionTranslatorBase.CALL_FUNCTION(self, inst)
    964 fn = self.pop()
--> 965 self.call_function(fn, args, {})

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:474, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    470 assert all(
    471     isinstance(x, VariableTracker)
    472     for x in itertools.chain(args, kwargs.values())
    473 )
--> 474 self.push(fn.call_function(self, args, kwargs))

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py:244, in NNModuleVariable.call_function(self, tx, args, kwargs)
    243 options["source"] = fn_source
--> 244 return tx.inline_user_function_return(
    245     variables.UserFunctionVariable(fn, **options),
    246     args,
    247     kwargs,
    248 )

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:510, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    509 try:
--> 510     result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
    511     self.output.guards.update(fn.guards)

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1806, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   1805 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 1806     return cls.inline_call_(parent, func, args, kwargs)

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1862, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   1861 try:
-> 1862     tracer.run()
   1863 except exc.SkipFrame as e:

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:576, in InstructionTranslatorBase.run(self)
    572 self.output.push_tx(self)
    573 while (
    574     self.instruction_pointer is not None
    575     and not self.output.should_exit
--> 576     and self.step()
    577 ):
    578     pass

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:540, in InstructionTranslatorBase.step(self)
    539     unimplemented(f"missing: {inst.opname}")
--> 540 getattr(self, inst.opname)(inst)
    542 return inst.opname != "RETURN_VALUE"

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:688, in InstructionTranslatorBase.LOAD_GLOBAL(self, inst)
    687 source = self.get_global_source(name)
--> 688 self.push(VariableBuilder(self, source)(value))

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:172, in VariableBuilder.__call__(self, value)
    171     return self.tx.output.side_effects[value]
--> 172 return self._wrap(value).clone(**self.options())

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:454, in VariableBuilder._wrap(self, value)
    447     return TorchVariable(
    448         value,
    449         source=self.source,
    450         guards=make_guards(GuardBuilder.FUNCTION_MATCH),
    451     )
    452 elif (
    453     istype(value, (type, types.FunctionType))
--> 454     and skipfiles.check(getfile(value), allow_torch=True)
    455     and not inspect.getattr_static(value, "_torchdynamo_inline", False)
    456 ):
    457     return SkipFilesVariable(
    458         value,
    459         source=self.source,
    460         guards=make_guards(GuardBuilder.FUNCTION_MATCH),
    461     )

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/utils.py:586, in getfile(obj)
    585 try:
--> 586     return inspect.getfile(obj)
    587 except TypeError:

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/package/package_importer.py:691, in _patched_getfile(object)
    690         return _package_imported_modules[object.__module__].__file__
--> 691 return _orig_getfile(object)

File ~/mambaforge/envs/vae/lib/python3.10/inspect.py:785, in getfile(object)
    784     if object.__module__ == '__main__':
--> 785         raise OSError('source code not available')
    786 raise TypeError('{!r} is a built-in class'.format(object))

OSError: source code not available

from user code:
   File "/tmp/ipykernel_3332697/3053697966.py", line 380, in forward
    x = self.encoder(x, masks)
  File "/tmp/ipykernel_3332697/3053697966.py", line 292, in forward
    if isinstance(module, ResConvBlock) or isinstance(module, ResConvTransposeBlock):

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

The above exception was the direct cause of the following exception:

InternalTorchDynamoError                  Traceback (most recent call last)
Cell In[3], line 22
     20     images = images.to(device)
     21     masks = masks.to(device)
---> 22     images_pred, mu, log_var = model(images, masks)
     23     loss = MaskedMSELoss()(images_pred, images, masks, mu, log_var)
     24 # Exits autocast before backward().
     25 # Backward passes under autocast are not recommended.
     26 # Backward ops run in the same dtype autocast chose for corresponding forward ops.

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:82, in OptimizedModule.forward(self, *args, **kwargs)
     81 def forward(self, *args, **kwargs):
---> 82     return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:209, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    207 dynamic_ctx.__enter__()
    208 try:
--> 209     return fn(*args, **kwargs)
    210 finally:
    211     set_eval_frame(prior)

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:337, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
    334             return hijacked_callback(frame, cache_size, hooks)
    336 with compile_lock:
--> 337     return callback(frame, cache_size, hooks)

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:404, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
    402 counters["frames"]["total"] += 1
    403 try:
--> 404     result = inner_convert(frame, cache_size, hooks)
    405     counters["frames"]["ok"] += 1
    406     return result

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:104, in wrap_convert_context.<locals>._fn(*args, **kwargs)
    102 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    103 try:
--> 104     return fn(*args, **kwargs)
    105 finally:
    106     torch._C._set_grad_enabled(prior_grad_mode)

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:262, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
    259 global initial_grad_state
    260 initial_grad_state = torch.is_grad_enabled()
--> 262 return _compile(
    263     frame.f_code,
    264     frame.f_globals,
    265     frame.f_locals,
    266     frame.f_builtins,
    267     compiler_fn,
    268     one_graph,
    269     export,
    270     hooks,
    271     frame,
    272 )

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/utils.py:163, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    161     compilation_metrics[key] = []
    162 t0 = time.time()
--> 163 r = func(*args, **kwargs)
    164 time_spent = time.time() - t0
    165 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:394, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    392 except Exception as e:
    393     exception_handler(e, code, frame)
--> 394     raise InternalTorchDynamoError() from e

InternalTorchDynamoError: 

Do you have a full repro?