Model compilation error - BrokenProcessPool

Hi,

I am running everything with the new official PyTorch Docker image, in a notebook, on a 3090 GPU.

I have a more or less standard model with 1D convolutions and transformer modules:

JITModel(
  (softmax): Softmax(dim=2)
  (stft): Identity()
  (audio_normalize): Identity()
  (encoder): BasicEncoder(
    (layers): Sequential(
      (0): EncoderBlock(
        (skip_add): SkipAddBlock()
        (layers): Sequential(
          (0): Conv1d(161, 512, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
          (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Dropout(p=0.1, inplace=False)
          (4): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
          (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): Dropout(p=0.1, inplace=False)
        )
      )
...
      (8): TransformerLayer(
        (attention): MultiHeadAttention(
          (QKV): Linear(in_features=512, out_features=1536, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
          (softmax): Softmax(dim=-1)
        )
        (activation): ReLU()
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
 ...
      (20): TransformerLayer(
        (attention): MultiHeadAttention(
          (QKV): Linear(in_features=512, out_features=1536, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
          (softmax): Softmax(dim=-1)
        )
        (activation): ReLU()
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Conv1d(512, 998, kernel_size=(1,), stride=(1,))
  (decoder): Identity()
  (skip_decoder): Identity()
)

I had no problems exporting this model to TorchScript (.jit) and/or ONNX; all of the modules seem pretty standard.

When I try the following compilation:

# Reproduction setup: reset TorchDynamo's state, enable verbose logging, and
# strip the model down to its first eight encoder layers.
import torch._dynamo
torch._dynamo.reset()
torch._dynamo.config.verbose=True
# With suppress_errors enabled, a frame whose compilation fails is logged
# ("WON'T CONVERT ...") and run eagerly instead of raising to the caller.
torch._dynamo.config.suppress_errors = True

model.stft = nn.Identity()  # stft causes compiler errors, removed it for simplicity
model.decoder = nn.Identity() # removed this for simplicity
model.skip_decoder = nn.Identity() # removed these for simplicity
model.audio_normalize = nn.Identity() # removed these for simplicity
model.encoder.layers = model.encoder.layers[:8]  # removed transformer layers, only plain 1D convolutions remain

def test(model):
    """Run one gradient-free forward pass of ``model`` on random input.

    The input is a fixed-shape random tensor of shape (10, 161, 1500) that is
    moved to ``device`` (defined outside this snippet). The model output is
    discarded; the function always returns ``True``.
    """
    with torch.no_grad():
        batch_size = 10
        inputs = torch.randn(batch_size, 161, 100 * 15)  # just a fixed shape analogous to a 15 second audio
        inputs = inputs.to(device)
        # The result is intentionally unused: only the forward pass matters here.
        model_outputs = model(inputs)
        return True
    
# Compile the whole `test` function (not only the model). dynamic=False turns
# off dynamic-shape tracing; "reduce-overhead" is the torch.compile mode that
# reduces Python overhead using CUDA graphs.
# NOTE(review): torch.compile is lazy, so the failure presumably appears on the
# first call of compiled_test(model), which is not shown in this snippet.
compiled_test = torch.compile(test,
                              dynamic=False,
                              mode="reduce-overhead")   

The compiler fails with the following error:

fake_example_inputs())
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
    return compile_fx(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
    return self.compile_to_module().call
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
    mod = PyCodeCache.load(code)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_keras/og/cogya2y4dpehvdn65qmondwweu6i7skafpim2xqcvjymemciffdg.py", line 21, in <module>
    triton__0 = async_compile.triton('''
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
    future = self.process_pool().submit(
  File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
    raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore


[2023-03-23 06:30:15,040] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT forward /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py line 215 
216           0 LOAD_FAST                0 (self)
              2 GET_ITER
        >>    4 FOR_ITER                 6 (to 18)
              6 STORE_FAST               2 (module)

217           8 LOAD_FAST                2 (module)
             10 LOAD_FAST                1 (input)
             12 CALL_FUNCTION            1
             14 STORE_FAST               1 (input)
             16 JUMP_ABSOLUTE            2 (to 4)

218     >>   18 LOAD_FAST                1 (input)
             20 RETURN_VALUE

 ========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
    return compile_fx(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
    return self.compile_to_module().call
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
    mod = PyCodeCache.load(code)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_keras/og/cogya2y4dpehvdn65qmondwweu6i7skafpim2xqcvjymemciffdg.py", line 21, in <module>
    triton__0 = async_compile.triton('''
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
    future = self.process_pool().submit(
  File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
    raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore


[2023-03-23 06:30:15,408] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT forward /home/keras/notebook/nvme/aveysov/silero-models-research/models/stt_model_blocks.py line 213 
214           0 LOAD_FAST                1 (x)
              2 STORE_FAST               2 (inputs)

215           4 LOAD_FAST                0 (self)
              6 LOAD_METHOD              0 (layers)
              8 LOAD_FAST                1 (x)
             10 CALL_METHOD              1
             12 STORE_FAST               1 (x)

216          14 LOAD_FAST                0 (self)
             16 LOAD_ATTR                1 (skip)
             18 POP_JUMP_IF_FALSE       17 (to 34)

217          20 LOAD_FAST                0 (self)
             22 LOAD_ATTR                2 (skip_add)
             24 LOAD_METHOD              3 (add)
             26 LOAD_FAST                1 (x)
             28 LOAD_FAST                2 (inputs)
             30 CALL_METHOD              2
             32 STORE_FAST               1 (x)

218     >>   34 LOAD_FAST                1 (x)
             36 RETURN_VALUE

 ========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
    return compile_fx(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
    return self.compile_to_module().call
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
    mod = PyCodeCache.load(code)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_keras/f7/cf7fsrsawsmeg76n43nynszodl72ngsqfiii6u46pkrudq5gsxgi.py", line 21, in <module>
    triton__0 = async_compile.triton('''
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
    future = self.process_pool().submit(
  File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
    raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore


[2023-03-23 06:30:15,433] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT add /home/keras/notebook/nvme/aveysov/silero-models-research/models/stt_model_blocks.py line 306 
307           0 LOAD_FAST                1 (x)
              2 LOAD_FAST                2 (skip)
              4 BINARY_ADD
              6 RETURN_VALUE

 ========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
    return compile_fx(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
    return self.compile_to_module().call
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
    mod = PyCodeCache.load(code)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 528, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_keras/uc/cuc7cqzux2s7y6hjb6uvz76lzx6gosxhmfgrbewbrrqt453xu7kz.py", line 20, in <module>
    triton__0 = async_compile.triton('''
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 683, in triton
    future = self.process_pool().submit(
  File "/opt/conda/lib/python3.10/concurrent/futures/process.py", line 715, in submit
    raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore

However, if I go to the extreme and do the following:

model.encoder = torch.nn.Sequential(*[torch.nn.Conv1d(161, 512, 5)])

The model compiles.
Then I try the following:

model.encoder = torch.nn.Sequential(*model.encoder.layers[0].layers)

which produces the following model:

JITModel(
  (softmax): Softmax(dim=2)
  (stft): Identity()
  (audio_normalize): Identity()
  (encoder): Sequential(
    (0): Conv1d(161, 512, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.1, inplace=False)
    (4): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): Dropout(p=0.1, inplace=False)
  )
  (fc): Conv1d(512, 998, kernel_size=(1,), stride=(1,))
  (decoder): Identity()
  (skip_decoder): Identity()
)

And it still has a similar compilation error, which is strange, because the modules are standard and the model is very simple. I am not sure whether I am doing something wrong or making some wrong assumptions.

Perhaps @yanboliang knows

I am having the same issue.


  File "/home/luban/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/luban/miniconda3/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 541, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/luban/miniconda3/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/luban/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/luban/miniconda3/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: compile_fn raised BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore

@leaf_y By the way, do you run a plain .py script, or do you also use IPython / a notebook?

I ran a plain .py script. Monitoring the processes while dynamo was compiling, it looks like a huge number of processes are being dispatched until it throws this BrokenProcessPool error.

I see, many thanks. At least one more data point.