Hi,
I'm running everything using the new official PyTorch Docker image, in a notebook, on an RTX 3090 GPU.
I have a more or less standard model with 1D convolutions and transformer modules:
JITModel(
(softmax): Softmax(dim=2)
(stft): Identity()
(audio_normalize): Identity()
(encoder): BasicEncoder(
(layers): Sequential(
(0): EncoderBlock(
(skip_add): SkipAddBlock()
(layers): Sequential(
(0): Conv1d(161, 512, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
(1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Dropout(p=0.1, inplace=False)
(4): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
(5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): Dropout(p=0.1, inplace=False)
)
)
...
(8): TransformerLayer(
(attention): MultiHeadAttention(
(QKV): Linear(in_features=512, out_features=1536, bias=True)
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(softmax): Softmax(dim=-1)
)
(activation): ReLU()
(linear1): Linear(in_features=512, out_features=512, bias=True)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
(dropout1): Dropout(p=0.1, inplace=False)
(dropout2): Dropout(p=0.1, inplace=False)
)
...
(20): TransformerLayer(
(attention): MultiHeadAttention(
(QKV): Linear(in_features=512, out_features=1536, bias=True)
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(softmax): Softmax(dim=-1)
)
(activation): ReLU()
(linear1): Linear(in_features=512, out_features=512, bias=True)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
(dropout1): Dropout(p=0.1, inplace=False)
(dropout2): Dropout(p=0.1, inplace=False)
)
)
)
(fc): Conv1d(512, 998, kernel_size=(1,), stride=(1,))
(decoder): Identity()
(skip_decoder): Identity()
)
I had no problems exporting this model to TorchScript (`.jit`) and/or ONNX; all of the modules seem pretty standard.
When I try the following compilation:
import torch
import torch._dynamo
import torch._inductor.config

torch._dynamo.reset()
torch._dynamo.config.verbose = True
# With suppress_errors=True a failed compilation is logged and the frame falls
# back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True
# By default Inductor compiles Triton kernels in a pool of worker processes.
# When a worker dies, the parent only sees the opaque
# "BrokenProcessPool: A child process terminated abruptly" error, which hides
# the real failure. Compiling in-process surfaces the worker's actual exception.
# NOTE(review): a small /dev/shm inside the Docker container is one possible
# reason for workers dying - TODO confirm (e.g. rerun with --ipc=host or a
# larger --shm-size).
torch._inductor.config.compile_threads = 1

# `nn` is not imported in this snippet; `torch.nn` is always available.
model.stft = torch.nn.Identity()  # stft causes compiler errors, removed it for simplicity
model.decoder = torch.nn.Identity()  # removed this for simplicity
model.skip_decoder = torch.nn.Identity()  # removed these for simplicity
model.audio_normalize = torch.nn.Identity()  # removed these for simplicity
# Keep only the first 8 encoder layers: the transformer layers are dropped and
# only the plain 1D convolution blocks remain.
model.encoder.layers = model.encoder.layers[:8]


def test(model):
    """Run one forward pass of ``model`` on a fixed-shape random batch without autograd.

    ``device`` is assumed to be defined earlier in the notebook (TODO confirm).

    Returns:
        True once the forward pass has completed; the outputs are discarded.
    """
    with torch.no_grad():
        batch_size = 10
        # Fixed shape analogous to 15 seconds of audio:
        # (batch, 161 input channels, 100 * 15 time steps).
        inputs = torch.randn(batch_size, 161, 100 * 15)
        inputs = inputs.to(device)
        model(inputs)
    return True


compiled_test = torch.compile(test, dynamic=False, mode="reduce-overhead")
# torch.compile is lazy: tracing and backend compilation (and therefore any
# compiler error) only happen on the first call.
compiled_test(model)
the compiler fails with the following error (the beginning of the first traceback was cut off):
fake_example_inputs())
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
return compile_fx(
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
return inner_compile(
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
compiled_fn = graph.compile_to_fn()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
return self.compile_to_module().call
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
mod = PyCodeCache.load(code)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_keras/og/cogya2y4dpehvdn65qmondwweu6i7skafpim2xqcvjymemciffdg.py", line 21, in <module>
triton__0 = async_compile.triton('''
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
future = self.process_pool().submit(
File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
self.output.compile_subgraph(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
[2023-03-23 06:30:15,040] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT forward /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py line 215
216 0 LOAD_FAST 0 (self)
2 GET_ITER
>> 4 FOR_ITER 6 (to 18)
6 STORE_FAST 2 (module)
217 8 LOAD_FAST 2 (module)
10 LOAD_FAST 1 (input)
12 CALL_FUNCTION 1
14 STORE_FAST 1 (input)
16 JUMP_ABSOLUTE 2 (to 4)
218 >> 18 LOAD_FAST 1 (input)
20 RETURN_VALUE
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
return compile_fx(
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
return inner_compile(
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
compiled_fn = graph.compile_to_fn()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
return self.compile_to_module().call
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
mod = PyCodeCache.load(code)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_keras/og/cogya2y4dpehvdn65qmondwweu6i7skafpim2xqcvjymemciffdg.py", line 21, in <module>
triton__0 = async_compile.triton('''
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
future = self.process_pool().submit(
File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
self.output.compile_subgraph(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
[2023-03-23 06:30:15,408] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT forward /home/keras/notebook/nvme/aveysov/silero-models-research/models/stt_model_blocks.py line 213
214 0 LOAD_FAST 1 (x)
2 STORE_FAST 2 (inputs)
215 4 LOAD_FAST 0 (self)
6 LOAD_METHOD 0 (layers)
8 LOAD_FAST 1 (x)
10 CALL_METHOD 1
12 STORE_FAST 1 (x)
216 14 LOAD_FAST 0 (self)
16 LOAD_ATTR 1 (skip)
18 POP_JUMP_IF_FALSE 17 (to 34)
217 20 LOAD_FAST 0 (self)
22 LOAD_ATTR 2 (skip_add)
24 LOAD_METHOD 3 (add)
26 LOAD_FAST 1 (x)
28 LOAD_FAST 2 (inputs)
30 CALL_METHOD 2
32 STORE_FAST 1 (x)
218 >> 34 LOAD_FAST 1 (x)
36 RETURN_VALUE
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
return compile_fx(
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
return inner_compile(
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
compiled_fn = graph.compile_to_fn()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
return self.compile_to_module().call
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
mod = PyCodeCache.load(code)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_keras/f7/cf7fsrsawsmeg76n43nynszodl72ngsqfiii6u46pkrudq5gsxgi.py", line 21, in <module>
triton__0 = async_compile.triton('''
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
future = self.process_pool().submit(
File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
self.output.compile_subgraph(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
[2023-03-23 06:30:15,433] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT add /home/keras/notebook/nvme/aveysov/silero-models-research/models/stt_model_blocks.py line 306
307 0 LOAD_FAST 1 (x)
2 LOAD_FAST 2 (skip)
4 BINARY_ADD
6 RETURN_VALUE
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
return compile_fx(
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
return inner_compile(
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
compiled_fn = graph.compile_to_fn()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
return self.compile_to_module().call
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
mod = PyCodeCache.load(code)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_keras/uc/cuc7cqzux2s7y6hjb6uvz76lzx6gosxhmfgrbewbrrqt453xu7kz.py", line 20, in <module>
triton__0 = async_compile.triton('''
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
future = self.process_pool().submit(
File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
self.output.compile_subgraph(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
However, if I go to the extreme and do the following:
# Replace the whole encoder with a single Conv1d (161 -> 512 channels, kernel size 5).
model.encoder = torch.nn.Sequential(torch.nn.Conv1d(161, 512, kernel_size=5))
The model compiles.
Then I try the following:
# Flatten the first EncoderBlock: reuse its inner layers directly as the encoder.
first_block_layers = model.encoder.layers[0].layers
model.encoder = torch.nn.Sequential(*first_block_layers)
which produces the following model:
JITModel(
(softmax): Softmax(dim=2)
(stft): Identity()
(audio_normalize): Identity()
(encoder): Sequential(
(0): Conv1d(161, 512, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
(1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Dropout(p=0.1, inplace=False)
(4): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
(5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): Dropout(p=0.1, inplace=False)
)
(fc): Conv1d(512, 998, kernel_size=(1,), stride=(1,))
(decoder): Identity()
(skip_decoder): Identity()
)
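It still fails with a similar compilation error, which is strange, because the modules are standard and the model is very simple. Note that every failure above is the same `BrokenProcessPool`, raised while Inductor compiles Triton kernels in worker processes, so the real error from the crashed worker is not shown. I am not sure whether I am doing something wrong or making incorrect assumptions.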