I have the following custom class, and PyTorch 2.0's `torch.compile` fails on it when I run the training loop of my compiled VAE network:
class MaskedSequential(nn.Sequential):
    """Like ``nn.Sequential``, but threads a shared ``masks`` tensor through the chain.

    Mask-aware child modules (``ResConvBlock`` / ``ResConvTransposeBlock``) are
    called as ``module(x, masks)``; every other child receives only ``x``.
    """

    def forward(self, *inputs):
        """Apply every child module in order, cascading ``x`` and passing ``masks``
        only to the modules that accept it.

        Args:
            *inputs: exactly two positional tensors, ``(x, masks)``.

        Returns:
            The tensor produced by the last child module.
        """
        x, masks = inputs
        # nn.Sequential is itself iterable over its children; use the public
        # protocol instead of reaching into the private `_modules` dict.
        for module in self:
            # One isinstance call with a tuple replaces the chained
            # `isinstance(...) or isinstance(...)`.
            if isinstance(module, (ResConvBlock, ResConvTransposeBlock)):
                x = module(x, masks)
            else:
                x = module(x)
        return x
I get the following error:
OSError Traceback (most recent call last)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:324, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
323 try:
--> 324 out_code = transform_code_object(code, transform)
325 orig_code_map[out_code] = code
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:445, in transform_code_object(code, transformations, safe)
443 propagate_line_nums(instructions)
--> 445 transformations(instructions, code_options)
446 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:311, in _compile.<locals>.transform(instructions, code_options)
299 tracer = InstructionTranslator(
300 instructions,
301 code,
(...)
309 mutated_closure_cell_contents,
310 )
--> 311 tracer.run()
312 output = tracer.output
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1726, in InstructionTranslator.run(self)
1725 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1726 super().run()
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:576, in InstructionTranslatorBase.run(self)
572 self.output.push_tx(self)
573 while (
574 self.instruction_pointer is not None
575 and not self.output.should_exit
--> 576 and self.step()
577 ):
578 pass
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:540, in InstructionTranslatorBase.step(self)
539 unimplemented(f"missing: {inst.opname}")
--> 540 getattr(self, inst.opname)(inst)
542 return inst.opname != "RETURN_VALUE"
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:342, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
341 try:
--> 342 return inner_fn(self, inst)
343 except Unsupported as excp:
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:965, in InstructionTranslatorBase.CALL_FUNCTION(self, inst)
964 fn = self.pop()
--> 965 self.call_function(fn, args, {})
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:474, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
470 assert all(
471 isinstance(x, VariableTracker)
472 for x in itertools.chain(args, kwargs.values())
473 )
--> 474 self.push(fn.call_function(self, args, kwargs))
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py:244, in NNModuleVariable.call_function(self, tx, args, kwargs)
243 options["source"] = fn_source
--> 244 return tx.inline_user_function_return(
245 variables.UserFunctionVariable(fn, **options),
246 args,
247 kwargs,
248 )
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:510, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
509 try:
--> 510 result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
511 self.output.guards.update(fn.guards)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1806, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
1805 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 1806 return cls.inline_call_(parent, func, args, kwargs)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1862, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
1861 try:
-> 1862 tracer.run()
1863 except exc.SkipFrame as e:
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:576, in InstructionTranslatorBase.run(self)
572 self.output.push_tx(self)
573 while (
574 self.instruction_pointer is not None
575 and not self.output.should_exit
--> 576 and self.step()
577 ):
578 pass
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:540, in InstructionTranslatorBase.step(self)
539 unimplemented(f"missing: {inst.opname}")
--> 540 getattr(self, inst.opname)(inst)
542 return inst.opname != "RETURN_VALUE"
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:688, in InstructionTranslatorBase.LOAD_GLOBAL(self, inst)
687 source = self.get_global_source(name)
--> 688 self.push(VariableBuilder(self, source)(value))
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:172, in VariableBuilder.__call__(self, value)
171 return self.tx.output.side_effects[value]
--> 172 return self._wrap(value).clone(**self.options())
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:454, in VariableBuilder._wrap(self, value)
447 return TorchVariable(
448 value,
449 source=self.source,
450 guards=make_guards(GuardBuilder.FUNCTION_MATCH),
451 )
452 elif (
453 istype(value, (type, types.FunctionType))
--> 454 and skipfiles.check(getfile(value), allow_torch=True)
455 and not inspect.getattr_static(value, "_torchdynamo_inline", False)
456 ):
457 return SkipFilesVariable(
458 value,
459 source=self.source,
460 guards=make_guards(GuardBuilder.FUNCTION_MATCH),
461 )
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/utils.py:586, in getfile(obj)
585 try:
--> 586 return inspect.getfile(obj)
587 except TypeError:
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/package/package_importer.py:691, in _patched_getfile(object)
690 return _package_imported_modules[object.__module__].__file__
--> 691 return _orig_getfile(object)
File ~/mambaforge/envs/vae/lib/python3.10/inspect.py:785, in getfile(object)
784 if object.__module__ == '__main__':
--> 785 raise OSError('source code not available')
786 raise TypeError('{!r} is a built-in class'.format(object))
OSError: source code not available
from user code:
File "/tmp/ipykernel_3332697/3053697966.py", line 380, in forward
x = self.encoder(x, masks)
File "/tmp/ipykernel_3332697/3053697966.py", line 292, in forward
if isinstance(module, ResConvBlock) or isinstance(module, ResConvTransposeBlock):
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
The above exception was the direct cause of the following exception:
InternalTorchDynamoError Traceback (most recent call last)
Cell In[3], line 22
20 images = images.to(device)
21 masks = masks.to(device)
---> 22 images_pred, mu, log_var = model(images, masks)
23 loss = MaskedMSELoss()(images_pred, images, masks, mu, log_var)
24 # Exits autocast before backward().
25 # Backward passes under autocast are not recommended.
26 # Backward ops run in the same dtype autocast chose for corresponding forward ops.
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:82, in OptimizedModule.forward(self, *args, **kwargs)
81 def forward(self, *args, **kwargs):
---> 82 return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:209, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
207 dynamic_ctx.__enter__()
208 try:
--> 209 return fn(*args, **kwargs)
210 finally:
211 set_eval_frame(prior)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:337, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
334 return hijacked_callback(frame, cache_size, hooks)
336 with compile_lock:
--> 337 return callback(frame, cache_size, hooks)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:404, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
402 counters["frames"]["total"] += 1
403 try:
--> 404 result = inner_convert(frame, cache_size, hooks)
405 counters["frames"]["ok"] += 1
406 return result
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:104, in wrap_convert_context.<locals>._fn(*args, **kwargs)
102 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
103 try:
--> 104 return fn(*args, **kwargs)
105 finally:
106 torch._C._set_grad_enabled(prior_grad_mode)
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:262, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
259 global initial_grad_state
260 initial_grad_state = torch.is_grad_enabled()
--> 262 return _compile(
263 frame.f_code,
264 frame.f_globals,
265 frame.f_locals,
266 frame.f_builtins,
267 compiler_fn,
268 one_graph,
269 export,
270 hooks,
271 frame,
272 )
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/utils.py:163, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
161 compilation_metrics[key] = []
162 t0 = time.time()
--> 163 r = func(*args, **kwargs)
164 time_spent = time.time() - t0
165 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File ~/mambaforge/envs/vae/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:394, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
392 except Exception as e:
393 exception_handler(e, code, frame)
--> 394 raise InternalTorchDynamoError() from e
InternalTorchDynamoError: