Hi,
I am training a recurrent model. The model trains well and everything works with PyTorch 1.13. I am trying to move to PyTorch 2.0. The only thing I add to my code is
torch._logging.set_logs(inductor=logging.DEBUG)
model = torch.compile(model)
but the code seems to hang during compilation.
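For context, this is roughly how the model is set up and trained; the GRU, optimizer, shapes and names below are placeholders rather than my actual code:

import logging
import torch
import torch.nn as nn

torch._logging.set_logs(inductor=logging.DEBUG)

# placeholder recurrent model; the real one is more involved
model = nn.GRU(input_size=96, hidden_size=96, batch_first=True).cuda()
model = torch.compile(model)  # the only change relative to the 1.13 code

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(32, 42, 96, device="cuda")
target = torch.randn(32, 42, 96, device="cuda")

for step in range(3):
    optimizer.zero_grad()
    out, _ = model(x)  # the first call triggers compilation; this is where it hangs
    loss = nn.functional.mse_loss(out, target)
    loss.backward()
    optimizer.step()

The log output up to the point where it hangs: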
Start training
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
torch.has_cuda,
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
torch.has_cudnn,
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
torch.has_mps,
/home/roman/anaconda3/envs/clf2v2v1/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
torch.has_mkldnn,
[2023-12-27 17:58:30,274] [0/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 0
[2023-12-27 17:58:30,557] [0/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 0
[2023-12-27 17:58:30,595] [0/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 18
[2023-12-27 17:58:33,808] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,808] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,814] [0/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,814] [0/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,834] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,834] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,838] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,838] torch._inductor.triton_heuristics: [DEBUG] num_warps: 4, num_stages: 2
[2023-12-27 17:58:33,841] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,841] torch._inductor.triton_heuristics: [DEBUG] num_warps: 4, num_stages: 2
[2023-12-27 17:58:33,842] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,842] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,857] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,857] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,858] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,858] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,875] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,875] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 4, num_stages: 1
[2023-12-27 17:58:33,884] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,884] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 1, num_stages: 1
[2023-12-27 17:58:33,914] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,914] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 1, num_stages: 1
[2023-12-27 17:58:33,915] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:33,915] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, num_warps: 4, num_stages: 1
...
...
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, YBLOCK: 32, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, YBLOCK: 64, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, YBLOCK: 16, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1024, YBLOCK: 1, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,268] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 16, YBLOCK: 64, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 4 configs
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,283] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,288] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:39,288] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, RBLOCK: 2048, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 3 configs
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,292] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,301] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:39,301] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 1, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 4 configs
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,301] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 5 configs
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, YBLOCK: 32, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, YBLOCK: 64, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 256, YBLOCK: 16, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1024, YBLOCK: 1, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,320] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 16, YBLOCK: 64, num_warps: 4, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 5 configs
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, RBLOCK: 2048, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 128, RBLOCK: 8, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, RBLOCK: 64, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, RBLOCK: 512, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,338] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 64, RBLOCK: 4, num_warps: 8, num_stages: 1
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 3 configs
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 1, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 8, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,353] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 2, num_stages: 1
[2023-12-27 17:58:39,361] [3/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2023-12-27 17:58:39,361] [3/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 32, num_warps: 1, num_stages: 1
[2023-12-27 17:58:39,363] [3/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/7n/c7nxhlduwd4y55b643ceos54ajsbrvzorfcyzboivfiiv2jowl6q.py
[2023-12-27 17:58:39,364] [3/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 1
[2023-12-27 17:58:39,396] torch._inductor.cudagraph_trees: [INFO] recording cudagraph tree for (2, 32)
[2023-12-27 17:58:39,571] torch._inductor.cudagraph_trees: [DEBUG] Running warmup of function 0
[2023-12-27 17:58:40,814] [7/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 2
[2023-12-27 17:58:40,856] [7/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 2
[2023-12-27 17:58:40,858] [7/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 1
[2023-12-27 17:58:40,863] [7/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/23/c23erqwa2aie4tfk4lvzbs5z22k676piklaiuswiis6qhhu5o5j3.py
skipping cudagraphs due to multiple devices
[2023-12-27 17:58:40,863] [7/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 2
[2023-12-27 17:58:40,921] [8/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 3
[2023-12-27 17:58:40,926] [8/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 3
[2023-12-27 17:58:40,927] [8/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 1
[2023-12-27 17:58:40,930] [8/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/lf/clfvdj2m5j42hcxjcpgueo7fuaaiy5mj6yfo2icwjabeiz43wuvx.py
skipping cudagraphs for unknown reason
[2023-12-27 17:58:40,930] [8/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 3
[2023-12-27 17:58:40,983] [9/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 4
[2023-12-27 17:58:40,987] [9/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 4
[2023-12-27 17:58:40,989] [9/0] torch._inductor.scheduler: [INFO] Number of scheduler nodes after fusion 1
[2023-12-27 17:58:40,992] [9/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_roman/dq/cdq5nqwfk2smfdyga26bw44q45b77zjign43sxrdxsm53ietwdbm.py
skipping cudagraphs for unknown reason
[2023-12-27 17:58:40,992] [9/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 4
I tried using PyTorch 2.0.1; with that version the training runs, but I end up creating an ever-increasing number of graphs (torchinductor compiling FORWARDS graph 700+). I tried adding a graph break with torch._dynamo.graph_break(), but that does not help, and using torch.compile(dynamic=True) gives me a different error.
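For reference, this is roughly where I inserted the break and how I enabled dynamic shapes; TinyModel here is just an illustrative stand-in for my recurrent model:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # illustrative stand-in for the real recurrent model
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(16, 16, batch_first=True)
        self.head = nn.Linear(16, 16)

    def forward(self, x):
        out, _ = self.rnn(x)
        torch._dynamo.graph_break()  # explicit break I tried; did not stop the recompilations
        return self.head(out)

model = torch.compile(TinyModel().cuda())                  # default: graph count keeps growing
# model = torch.compile(TinyModel().cuda(), dynamic=True)  # dynamic shapes: gives a different error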
The (very long) output in this case looks like this:
Start training
[2023-12-27 19:11:42,042] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:11:42,103] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-12-27 19:11:43,923] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 0
...
...
[2023-12-27 19:31:03,674] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 121
[2023-12-27 19:31:03,741] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling BACKWARDS graph 120
[2023-12-27 19:31:06,912] torch._inductor.graph: [INFO] Using FallbackKernel: aten.index
[2023-12-27 19:31:09,331] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 120
Epoch: 2 | loss: 0.33797 | time taken: 50.78308
Epoch: 3 | loss: 0.50044 | time taken: 3.25101
...
Epoch: 99 | loss: 0.39229 | time taken: 3.40079
Epoch: 100 | loss: 0.31362 | time taken: 3.71140
[2023-12-27 19:37:26,737] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:37:26,757] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:37:26,774] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-12-27 19:37:26,789] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-12-27 19:37:26,802] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 126
[2023-12-27 19:37:26,828] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 126
[2023-12-27 19:37:26,828] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-12-27 19:37:26,838] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing proj_out
[2023-12-27 19:37:26,843] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing rearrange
...
...
[2023-12-27 19:37:46,862] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in forward>
[2023-12-27 19:37:47,383] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing <graph break in forward> (RETURN_VALUE)
[2023-12-27 19:37:47,392] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-12-27 19:37:49,019] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 154
[2023-12-27 19:37:49,042] torch._inductor.graph: [INFO] Using FallbackKernel: aten.cumsum
[2023-12-27 19:37:49,114] torch._inductor.graph: [INFO] Creating implicit fallback for:
target: aten._scaled_dot_product_efficient_attention.default
args[0]: TensorBox(
View(
ReinterpretView(
StorageBox(
ExternKernelOut(
name=buf42,
layout=FixedLayout('cuda', torch.float32, size=[42, 96], stride=[96, 1]),
inputs=[ReinterpretView(
StorageBox(
InputBuffer(name='arg7_1', layout=FixedLayout('cuda', torch.float32, size=[288], stride=[1]))
),
FixedLayout('cuda', torch.float32, size=[96], stride=[1]),
no origins?
), ReinterpretView(
StorageBox(
ComputedBuffer(name='buf41', layout=FixedLayout('cuda', torch.float32, size=[42, 1, 96], stride=[96, 96, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf40, i2 + 96 * i0)
tmp1 = load(arg112_1, i2 + 96 * i0)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[42, 1, 96],
origins={arg0_1, addmm_2, mul_5, var_mean, bmm_1, view_21, permute_11, permute_9, add_6, full_like, add_9, arg4_1, view_22, view_11, sub_1, view_12, view_24, view_28, permute_7, view_9, scalar_tensor, view_25, amax, div_10, sub, addmm_3, view_31, view_13, permute_3, arg3_1, permute_4, split, addmm, view_29, view_23, view_10, addmm_1, expand_3, view_19, view_6, sum_1, permute_10, arg1_1, add_7, arg5_1, add_8, view_27, arg2_1, div_8, arg112_1, add_5, rsqrt, view_5, add_4, view_26, view_17, permute_8, bmm, view_7, view_20, mul_6, clone, exp, arg111_1, permute_5, view_32, view_8, permute_6, split_1, view_30, arg113_1, where, view_18, div_9}
))
),
FixedLayout('cuda', torch.float32, size=(42, 96), stride=[96, 1]),
no origins?
), ReinterpretView(
StorageBox(
InputBuffer(name='arg6_1', layout=FixedLayout('cuda', torch.float32, size=[288, 96], stride=[96, 1]))
),
FixedLayout('cuda', torch.float32, size=[96, 96], stride=[1, 96]),
no origins?
)],
constant_args=(),
kwargs={'alpha': 1, 'beta': 1},
output_view=None,
origins={arg0_1, addmm_2, mul_5, var_mean, bmm_1, view_21, permute_11, permute_9, add_6, full_like, add_9, arg4_1, view_22, view_11, view_24, sub_1, view_12, view_28, permute_7, view_9, split_3, scalar_tensor, view_25, amax, div_10, sub, addmm_3, view_13, view_31, permute_3, arg3_1, permute_4, split, addmm, view_29, view_23, view_10, addmm_1, expand_3, view_19, split_2, view_6, sum_1, permute_10, arg1_1, arg5_1, add_7, add_8, view_27, permute_12, arg7_1, arg2_1, div_8, arg112_1, add_5, rsqrt, view_5, add_4, view_26, view_17, permute_8, view_33, bmm, view_7, view_20, arg6_1, addmm_4, mul_6, clone, exp, arg111_1, permute_5, view_32, view_8, permute_6, split_1, view_30, arg113_1, where, view_18, div_9}
)
),
FixedLayout('cuda', torch.float32, size=[8, 42, 12], stride=[12, 96, 1]),
no origins?
),
size=(1, 8, 42, 12),
reindex=lambda i0, i1, i2, i3: [i1, i2, i3],
origins={arg0_1, addmm_2, mul_5, var_mean, bmm_1, view_21, permute_11, permute_9, add_6, full_like, add_9, arg4_1, view_22, view_11, view_24, sub_1, view_12, view_28, permute_7, view_9, split_3, scalar_tensor, view_25, amax, div_10, view_34, sub, addmm_3, view_13, view_31, permute_3, arg3_1, permute_4, split, addmm, view_29, view_23, view_10, addmm_1, expand_3, view_19, split_2, view_6, sum_1, permute_10, arg1_1, arg5_1, add_7, add_8, view_27, permute_12, arg7_1, arg2_1, div_8, arg112_1, add_5, rsqrt, view_5, add_4, view_26, view_17, permute_8, view_33, bmm, view_7, view_20, arg6_1, addmm_4, mul_6, clone, view_39, permute_15, exp, arg111_1, permute_5, view_32, view_8, permute_6, split_1, view_30, arg113_1, view_42, where, view_18, div_9}
)
)
args[1]: TensorBox(
View(
PermuteView(data=View(
StorageBox(
Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf44, i2 + 96 * i0)
tmp1 = load(arg7_1, 96 + i2)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[32768, 1, 96],
origins={arg7_1, add_11, add_10, view_35, view_3, mm, view_36, split_3, permute_13, permute_1, permute_2, arg110_1, arg6_1, split_2}
)
),
size=(32768, 8, 12),
reindex=lambda i0, i1, i2: [i0, 0, 12*i1 + i2],
origins={arg7_1, add_11, add_10, view_3, view_35, mm, view_36, split_3, permute_13, permute_1, permute_2, arg110_1, arg6_1, split_2, view_40}
), dims=[1, 0, 2]),
size=(1, 8, 32768, 12),
reindex=lambda i0, i1, i2, i3: [i1, i2, i3],
origins={arg7_1, view_43, view_35, mm, permute_13, permute_1, arg110_1, arg6_1, split_2, permute_16, add_11, add_10, view_3, view_36, split_3, permute_2, view_40}
)
)
args[2]: TensorBox(
View(
PermuteView(data=View(
StorageBox(
Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf45, i2 + 96 * i0)
tmp1 = load(arg7_1, 192 + i2)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[32768, 1, 96],
origins={arg7_1, add_12, view_3, split_3, view_37, permute_14, mm_1, permute_1, arg6_1, view_38, arg110_1, split_2}
)
),
size=(32768, 8, 12),
reindex=lambda i0, i1, i2: [i0, 0, 12*i1 + i2],
origins={arg7_1, add_12, view_3, split_3, view_37, permute_14, mm_1, permute_1, view_41, arg6_1, view_38, arg110_1, split_2}
), dims=[1, 0, 2]),
size=(1, 8, 32768, 12),
reindex=lambda i0, i1, i2, i3: [i1, i2, i3],
origins={arg7_1, add_12, view_3, split_3, view_37, permute_14, view_44, mm_1, permute_1, view_41, arg6_1, view_38, arg110_1, permute_17, split_2}
)
)
args[3]: False
...
[2023-12-27 19:37:50,512] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 158
[2023-12-27 19:37:50,546] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 158
[2023-12-27 19:37:50,547] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-12-27 19:37:50,571] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in create_list_from_pred>
Eval stats: ...
Any idea how to make it work, or why PyTorch 2.1.1 hangs?