Torch.compile error for recurrent model

Hi,

I am training a recurrent model. The model trains well and everything works with PyTorch 1.13. I am trying to move to PyTorch 2.1.1. The only thing I add to my code is

import logging
import torch

torch._logging.set_logs(inductor=logging.DEBUG)
model = torch.compile(model)

but the code seems to hang during compilation:

Start training
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
[2023-12-27 17:58:30,274] [0/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 0
[2023-12-27 17:58:30,557] [0/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 0
[2023-12-27 17:58:30,595] [0/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 18
[2023-12-27 17:58:33,808] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,808] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,814] [0/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,814] [0/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,834] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,834] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,838] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,838] torch._inductor.triton_heuristics: [DEBUG] num_warps: 4, num_stages: 2
[2023-12-27 17:58:33,841] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,841] torch._inductor.triton_heuristics: [DEBUG] num_warps: 4, num_stages: 2
[2023-12-27 17:58:33,842] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,842] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,857] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,857] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,858] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,858] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,875] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,875] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,884] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,884] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 1, num_stages: 1
[2023-12-27 17:58:33,914] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,914] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 1, num_stages: 1
[2023-12-27 17:58:33,915] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,915] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
...
...
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, YBLOCK: 32, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, YBLOCK: 64, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, YBLOCK: 16, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1024, YBLOCK: 1, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 16, YBLOCK: 64, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 4 configs
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,288] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:39,288] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, RBLOCK: 2048, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 3 configs
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,301] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:39,301] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 1, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 4 configs
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 5 configs
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, YBLOCK: 32, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, YBLOCK: 64, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, YBLOCK: 16, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1024, YBLOCK: 1, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 16, YBLOCK: 64, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 5 configs
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, RBLOCK: 2048, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, RBLOCK: 8, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, RBLOCK: 64, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, RBLOCK: 512, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, RBLOCK: 4, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 3 configs
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,361] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:39,361] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 1, num_stages: 1
[2023-12-27 17:58:39,363] [3/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/7n/c7nxhlduwd4y55b643ceos54ajsbrvzorfcyzboivfiiv2jowl6q.py
[2023-12-27 17:58:39,364] [3/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 1
[2023-12-27 17:58:39,396] torch._inductor.cudagraph_trees: [INFO] recording cudagraph tree for (2, 32)
[2023-12-27 17:58:39,571] torch._inductor.cudagraph_trees: [DEBUG] Running warmup of function 0
[2023-12-27 17:58:40,814] [7/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 2
[2023-12-27 17:58:40,856] [7/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 2
[2023-12-27 17:58:40,858] [7/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 1
[2023-12-27 17:58:40,863] [7/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/23/c23erqwa2aie4tfk4lvzbs5z22k676piklaiuswiis6qhhu5o5j3.py
skipping cudagraphs due to multiple devices
[2023-12-27 17:58:40,863] [7/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 2
[2023-12-27 17:58:40,921] [8/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 3
[2023-12-27 17:58:40,926] [8/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 3
[2023-12-27 17:58:40,927] [8/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 1
[2023-12-27 17:58:40,930] [8/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/lf/clfvdj2m5j42hcxjcpgueo7fuaaiy5mj6yfo2icwjabeiz43wuvx.py
skipping cudagraphs for unknown reason
[2023-12-27 17:58:40,930] [8/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 3
[2023-12-27 17:58:40,983] [9/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 4
[2023-12-27 17:58:40,987] [9/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 4
[2023-12-27 17:58:40,989] [9/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 1
[2023-12-27 17:58:40,992] [9/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/dq/cdq5nqwfk2smfdyga26bw44q45b77zjign43sxrdxsm53ietwdbm.py
skipping cudagraphs for unknown reason
[2023-12-27 17:58:40,992] [9/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 4

I also tried PyTorch 2.0.1; there the training runs, but it keeps compiling an ever-increasing number of graphs (torchinductor compiling FORWARDS graph 700+). I tried adding graph breaks with torch._dynamo.graph_break(), but that does not help, and torch.compile(dynamic=True) gives me a different error.
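
For reference, the two variants I tried look roughly like this (a simplified sketch; step_fn, x, h, and steps are placeholders for my actual per-step computation and state, not my real code):

import torch
import torch._dynamo

# Variant 1: force a graph break after every recurrent step
def run_loop(step_fn, x, h, steps):
    for t in range(steps):
        h = step_fn(x[t], h)            # placeholder for the model's per-step computation
        torch._dynamo.graph_break()     # did not help in my case
    return h

# Variant 2: compile with dynamic shapes enabled
# model = torch.compile(model, dynamic=True)  # gives me a different error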

The output in this case is very long and looks as follows:

Start training
[2023-12-27 19:11:42,042] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:11:42,103] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-12-27 19:11:43,923] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 0
...
...
[2023-12-27 19:31:03,674] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 121
[2023-12-27 19:31:03,741] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling BACKWARDS graph 120
[2023-12-27 19:31:06,912] torch._inductor.graph: [INFO] Using FallbackKernel: aten.index
[2023-12-27 19:31:09,331] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 120
Epoch: 2 	 | 	 loss: 0.33797 		 | 	 time taken: 50.78308
Epoch: 3 	 | 	 loss: 0.50044 		 | 	 time taken: 3.25101
...
Epoch: 99 	 | 	 loss: 0.39229 		 | 	 time taken: 3.40079
Epoch: 100 	 | 	 loss: 0.31362 		 | 	 time taken: 3.71140
[2023-12-27 19:37:26,737] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:37:26,757] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:37:26,774] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:37:26,789] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-12-27 19:37:26,802] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 126
[2023-12-27 19:37:26,828] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 126
[2023-12-27 19:37:26,828] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-12-27 19:37:26,838] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing proj_out
[2023-12-27 19:37:26,843] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing rearrange
...
...
[2023-12-27 19:37:46,862] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in forward>
[2023-12-27 19:37:47,383] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing <graph break in forward> (RETURN_VALUE)
[2023-12-27 19:37:47,392] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-12-27 19:37:49,019] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 154
[2023-12-27 19:37:49,042] torch._inductor.graph: [INFO] Using FallbackKernel: aten.cumsum
[2023-12-27 19:37:49,114] torch._inductor.graph: [INFO] Creating implicit fallback for:
  target: aten._scaled_dot_product_efficient_attention.default
  args[0]: TensorBox(
    View(
      ReinterpretView(
        StorageBox(
          ExternKernelOut(
            name=buf42,
            layout=FixedLayout('cuda', torch.float32, size=[42, 96], stride=[96, 1]),
            inputs=[ReinterpretView(
              StorageBox(
                InputBuffer(name='arg7_1', layout=FixedLayout('cuda', torch.float32, size=[288], stride=[1]))
              ),
              FixedLayout('cuda', torch.float32, size=[96], stride=[1]),
              no origins?
            ), ReinterpretView(
              StorageBox(
                ComputedBuffer(name='buf41', layout=FixedLayout('cuda', torch.float32, size=[42, 1, 96], stride=[96, 96, 1]), data=Pointwise(
                  'cuda',
                  torch.float32,
                  tmp0 = load(buf40, i2 + 96 * i0)
                  tmp1 = load(arg112_1, i2 + 96 * i0)
                  tmp2 = tmp0 + tmp1
                  return tmp2
                  ,
                  ranges=[42, 1, 96],
                  origins={arg0_1, addmm_2, mul_5, var_mean, bmm_1, view_21, permute_11, permute_9, add_6, full_like, add_9, arg4_1, view_22, view_11, sub_1, view_12, view_24, view_28, permute_7, view_9, scalar_tensor, view_25, amax, div_10, sub, addmm_3, view_31, view_13, permute_3, arg3_1, permute_4, split, addmm, view_29, view_23, view_10, addmm_1, expand_3, view_19, view_6, sum_1, permute_10, arg1_1, add_7, arg5_1, add_8, view_27, arg2_1, div_8, arg112_1, add_5, rsqrt, view_5, add_4, view_26, view_17, permute_8, bmm, view_7, view_20, mul_6, clone, exp, arg111_1, permute_5, view_32, view_8, permute_6, split_1, view_30, arg113_1, where, view_18, div_9}
                ))
              ),
              FixedLayout('cuda', torch.float32, size=(42, 96), stride=[96, 1]),
              no origins?
            ), ReinterpretView(
              StorageBox(
                InputBuffer(name='arg6_1', layout=FixedLayout('cuda', torch.float32, size=[288, 96], stride=[96, 1]))
              ),
              FixedLayout('cuda', torch.float32, size=[96, 96], stride=[1, 96]),
              no origins?
            )],
            constant_args=(),
            kwargs={'alpha': 1, 'beta': 1},
            output_view=None,
            origins={arg0_1, addmm_2, mul_5, var_mean, bmm_1, view_21, permute_11, permute_9, add_6, full_like, add_9, arg4_1, view_22, view_11, view_24, sub_1, view_12, view_28, permute_7, view_9, split_3, scalar_tensor, view_25, amax, div_10, sub, addmm_3, view_13, view_31, permute_3, arg3_1, permute_4, split, addmm, view_29, view_23, view_10, addmm_1, expand_3, view_19, split_2, view_6, sum_1, permute_10, arg1_1, arg5_1, add_7, add_8, view_27, permute_12, arg7_1, arg2_1, div_8, arg112_1, add_5, rsqrt, view_5, add_4, view_26, view_17, permute_8, view_33, bmm, view_7, view_20, arg6_1, addmm_4, mul_6, clone, exp, arg111_1, permute_5, view_32, view_8, permute_6, split_1, view_30, arg113_1, where, view_18, div_9}
          )
        ),
        FixedLayout('cuda', torch.float32, size=[8, 42, 12], stride=[12, 96, 1]),
        no origins?
      ),
      size=(1, 8, 42, 12),
      reindex=lambda i0, i1, i2, i3: [i1, i2, i3],
      origins={arg0_1, addmm_2, mul_5, var_mean, bmm_1, view_21, permute_11, permute_9, add_6, full_like, add_9, arg4_1, view_22, view_11, view_24, sub_1, view_12, view_28, permute_7, view_9, split_3, scalar_tensor, view_25, amax, div_10, view_34, sub, addmm_3, view_13, view_31, permute_3, arg3_1, permute_4, split, addmm, view_29, view_23, view_10, addmm_1, expand_3, view_19, split_2, view_6, sum_1, permute_10, arg1_1, arg5_1, add_7, add_8, view_27, permute_12, arg7_1, arg2_1, div_8, arg112_1, add_5, rsqrt, view_5, add_4, view_26, view_17, permute_8, view_33, bmm, view_7, view_20, arg6_1, addmm_4, mul_6, clone, view_39, permute_15, exp, arg111_1, permute_5, view_32, view_8, permute_6, split_1, view_30, arg113_1, view_42, where, view_18, div_9}
    )
  )
  args[1]: TensorBox(
    View(
      PermuteView(data=View(
        StorageBox(
          Pointwise(
            'cuda',
            torch.float32,
            tmp0 = load(buf44, i2 + 96 * i0)
            tmp1 = load(arg7_1, 96 + i2)
            tmp2 = tmp0 + tmp1
            return tmp2
            ,
            ranges=[32768, 1, 96],
            origins={arg7_1, add_11, add_10, view_35, view_3, mm, view_36, split_3, permute_13, permute_1, permute_2, arg110_1, arg6_1, split_2}
          )
        ),
        size=(32768, 8, 12),
        reindex=lambda i0, i1, i2: [i0, 0, 12*i1 + i2],
        origins={arg7_1, add_11, add_10, view_3, view_35, mm, view_36, split_3, permute_13, permute_1, permute_2, arg110_1, arg6_1, split_2, view_40}
      ), dims=[1, 0, 2]),
      size=(1, 8, 32768, 12),
      reindex=lambda i0, i1, i2, i3: [i1, i2, i3],
      origins={arg7_1, view_43, view_35, mm, permute_13, permute_1, arg110_1, arg6_1, split_2, permute_16, add_11, add_10, view_3, view_36, split_3, permute_2, view_40}
    )
  )
  args[2]: TensorBox(
    View(
      PermuteView(data=View(
        StorageBox(
          Pointwise(
            'cuda',
            torch.float32,
            tmp0 = load(buf45, i2 + 96 * i0)
            tmp1 = load(arg7_1, 192 + i2)
            tmp2 = tmp0 + tmp1
            return tmp2
            ,
            ranges=[32768, 1, 96],
            origins={arg7_1, add_12, view_3, split_3, view_37, permute_14, mm_1, permute_1, arg6_1, view_38, arg110_1, split_2}
          )
        ),
        size=(32768, 8, 12),
        reindex=lambda i0, i1, i2: [i0, 0, 12*i1 + i2],
        origins={arg7_1, add_12, view_3, split_3, view_37, permute_14, mm_1, permute_1, view_41, arg6_1, view_38, arg110_1, split_2}
      ), dims=[1, 0, 2]),
      size=(1, 8, 32768, 12),
      reindex=lambda i0, i1, i2, i3: [i1, i2, i3],
      origins={arg7_1, add_12, view_3, split_3, view_37, permute_14, view_44, mm_1, permute_1, view_41, arg6_1, view_38, arg110_1, permute_17, split_2}
    )
  )
  args[3]: False
...
[2023-12-27 19:37:50,512] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 158
[2023-12-27 19:37:50,546] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 158
[2023-12-27 19:37:50,547] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-12-27 19:37:50,571] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in create_list_from_pred>
Eval stats: ...

Any idea how to make this work, or why PyTorch 2.1.1 hangs?


Just a guess, but - given that you mentioned a recurrent model, long compile times, and large graphs - if your model has a for loop with a large iteration range, dynamo will (by default) try to trace through the loop and unroll it at compile time into one giant graph. If the loop is long enough, that unrolling can lead to very long compile times.
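
To make that concrete, here is a toy stand-in (not your model) for the pattern that gets unrolled:

import torch

class ToyRNN(torch.nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.cell = torch.nn.Linear(2 * dim, dim)

    def forward(self, x):  # x: (seq_len, batch, dim)
        h = torch.zeros(x.shape[1], x.shape[2], device=x.device)
        for t in range(x.shape[0]):  # dynamo traces through and unrolls this loop
            h = torch.tanh(self.cell(torch.cat([x[t], h], dim=-1)))
        return h

# Compiling the whole module produces one graph containing seq_len copies of the cell,
# so compile time grows with the sequence length.
toy = torch.compile(ToyRNN())
out = toy(torch.randn(512, 4, 16))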

Some hopefully relevant work: [feature request] `torch.scan` (also port `lax.fori_loop` / `lax.while_loop` / `lax.associative_scan` and hopefully parallelized associative scans) · Issue #50688 · pytorch/pytorch · GitHub

Otherwise, one option is to torch.compile just the body of your loop and call the compiled body from an outer (eager) loop. One issue you might hit is that, depending on how you use the loop iteration variable, you could get recompiles:

# compile only the per-step computation; the Python loop itself stays eager
compiled_body = torch.compile(loop_body)
for i in range(loop_iterations):
    out = compiled_body(out, i)
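
If you do see recompiles from the loop counter itself (dynamo specializes on plain Python int inputs by default), one thing you could try is passing the index as a tensor, since only its shape/dtype get guarded on rather than its value - roughly, assuming loop_body can accept a tensor index:

compiled_body = torch.compile(loop_body)
for i in range(loop_iterations):
    out = compiled_body(out, torch.tensor(i))  # tensor index avoids a per-value guard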