torch.compile on a train_step that includes both forward and backward

Should I attempt to compile a train_step function that includes both the forward pass and loss.backward()? When I run a simple example with TORCH_LOGS="aot_graphs", it appears there is still a graph break at loss.backward(), and separate forward and backward graphs are generated.
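
For context, here is the toy setup the logs correspond to. The model is reconstructed from the (self.scale * x).sin() line and the f32[1] / f32[8, 256] shapes in the traced graphs below; the names are mine:

    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # scale matches primals_1: "f32[1]" in the logs
            self.scale = torch.nn.Parameter(torch.ones(1))

        def forward(self, x):
            return (self.scale * x).sin()

    model = Model()
    # requires_grad=True matches the input_info in the logs (primals_2 requires grad)
    data = torch.randn(8, 256, requires_grad=True)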

    @torch.compile(fullgraph=False)
    def train_step(model, data):
        out = model(data)
        loss = out.sum()
        loss.backward()  # <- graph break happens here
        return out

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    _ = train_step(model, data)
    optimizer.step()
    optimizer.zero_grad()

Running with TORCH_LOGS="aot_graphs" shows a separate forward and backward graph being generated:

I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1475] [1/0_1] [__aot_graphs] aot_config id: 1, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True), InputAliasInfo(is_leaf=False, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, view_meta_sequence=None), OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, view_meta_sequence=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[FakeTensor(..., size=()), FakeTensor(..., size=(8, 256))], traced_tangents_descs=[TangentAOTInput(output=PlainAOTOutput(idx=0)), TangentAOTInput(output=PlainAOTOutput(idx=1))], subclass_inp_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None)], subclass_fw_graph_out_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None)], subclass_tangent_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=MemoryFormatMeta(size=None, stride=None, memory_format=torch.contiguous_format)), PlainTensorMeta(unwrapped_idx=1, memory_format=MemoryFormatMeta(size=None, stride=None, memory_format=torch.contiguous_format))], is_train=True, traced_tangent_metas=None, num_symints_saved_for_bw=0, grad_enabled_mutation=None, deterministic=False, static_input_indices=[0], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=[], num_backward_tokens=0, num_graphsafe_rng_states=0, graphsafe_rng_state_index=None), inner_meta=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True), InputAliasInfo(is_leaf=False, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, view_meta_sequence=None), OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, view_meta_sequence=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[FakeTensor(..., size=()), FakeTensor(..., size=(8, 256))], traced_tangents_descs=[TangentAOTInput(output=PlainAOTOutput(idx=0)), 
TangentAOTInput(output=PlainAOTOutput(idx=1))], subclass_inp_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None)], subclass_fw_graph_out_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None)], subclass_tangent_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=MemoryFormatMeta(size=None, stride=None, memory_format=torch.contiguous_format)), PlainTensorMeta(unwrapped_idx=1, memory_format=MemoryFormatMeta(size=None, stride=None, memory_format=torch.contiguous_format))], is_train=True, traced_tangent_metas=None, num_symints_saved_for_bw=0, grad_enabled_mutation=None, deterministic=False, static_input_indices=[0], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=[], num_backward_tokens=0, num_graphsafe_rng_states=0, graphsafe_rng_state_index=None)
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs] TRACED GRAPH
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]  ===== Forward graph 1 =====
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]  /opt/conda/envs/tt/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]     def forward(self, primals_1: "f32[1][1]cpu", primals_2: "f32[8, 256][256, 1]cpu"):
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]          # File: /home/yho_google_com/Documents/GitHub/sandbox/scripts/example_compile_lowerings.py:380 in forward, code: return (self.scale * x).sin()
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]         mul: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(primals_1, primals_2)
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]         sin: "f32[8, 256][256, 1]cpu" = torch.ops.aten.sin.default(mul);  mul = None
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]         
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]          # File: /home/yho_google_com/Documents/GitHub/sandbox/scripts/example_compile_lowerings.py:419 in train_step, code: loss = out.sum()
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]         sum_1: "f32[][]cpu" = torch.ops.aten.sum.default(sin)
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]         return (sum_1, sin, primals_1, primals_2)
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs]         
I0121 18:37:45.987000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1581] [1/0_1] [__aot_graphs] 
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs] TRACED GRAPH
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]  ===== Backward graph 1 =====
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]  <eval_with_key>.8 class GraphModule(torch.nn.Module):
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]     def forward(self, primals_1: "f32[1][1]cpu", primals_2: "f32[8, 256][256, 1]cpu", tangents_1: "f32[][]cpu", tangents_2: "f32[8, 256][256, 1]cpu"):
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]          # File: /home/yho_google_com/Documents/GitHub/sandbox/scripts/example_compile_lowerings.py:419 in train_step, code: loss = out.sum()
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         expand: "f32[8, 256][0, 0]cpu" = torch.ops.aten.expand.default(tangents_1, [8, 256]);  tangents_1 = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]          # File: /home/yho_google_com/Documents/GitHub/sandbox/scripts/example_compile_lowerings.py:419 in train_step, code: loss = out.sum()
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         add: "f32[8, 256][256, 1]cpu" = torch.ops.aten.add.Tensor(tangents_2, expand);  tangents_2 = expand = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]          # File: /home/yho_google_com/Documents/GitHub/sandbox/scripts/example_compile_lowerings.py:380 in forward, code: return (self.scale * x).sin()
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         mul: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(primals_1, primals_2)
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         cos: "f32[8, 256][256, 1]cpu" = torch.ops.aten.cos.default(mul);  mul = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         mul_1: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(add, cos);  add = cos = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         mul_2: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(mul_1, primals_1);  primals_1 = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         mul_3: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(mul_1, primals_2);  mul_1 = primals_2 = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         sum_2: "f32[1, 1][1, 1]cpu" = torch.ops.aten.sum.dim_IntList(mul_3, [0, 1], True);  mul_3 = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         view: "f32[1][1]cpu" = torch.ops.aten.view.default(sum_2, [1]);  sum_2 = None
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         return (view, mul_2)
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs]         
I0121 18:37:45.988000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_compile.py:1592] [1/0_1] [__aot_graphs] 
W0121 18:37:47.321000 3009140 site-packages/torch/_dynamo/utils.py:1915] [!0] ChromiumEventLogger: Start event not in stack, ignoring
V0121 18:37:47.426000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:203] [!0/5/0] [__aot_graphs] aot_config id: 2, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True), InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True), InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True), InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True), InputAliasInfo(is_leaf=True, mutates_data=True, mutates_metadata=False, mutations_hidden_from_autograd=False, mutations_under_no_grad_or_inference_mode=True, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, view_meta_sequence=None), OutputAliasInfo(output_type=<OutputType.is_input: 3>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=4, dynamic_dims=set(), requires_grad=False, view_meta_sequence=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], traced_tangents_descs=[], subclass_inp_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None), PlainTensorMeta(unwrapped_idx=2, memory_format=None), PlainTensorMeta(unwrapped_idx=3, memory_format=None), PlainTensorMeta(unwrapped_idx=4, memory_format=None)], subclass_fw_graph_out_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None)], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=False, static_input_indices=[1], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0, num_graphsafe_rng_states=0, graphsafe_rng_state_index=None),subclass_metadata=None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs] TRACED GRAPH
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]  ===== Forward graph 2 =====
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]  /opt/conda/envs/tt/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]     def forward(self, arg0_1: "f32[][]cpu", arg1_1: "f32[1][1]cpu", arg2_1: "f32[8, 256][256, 1]cpu", arg3_1: "f32[2048][1]cpu", arg4_1: "f32[2048][1]cpu"):
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         # No stacktrace found for following nodes
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         full: "f32[8, 256][256, 1]cpu" = torch.ops.aten.full.default([8, 256], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         expand: "f32[8, 256][0, 0]cpu" = torch.ops.aten.expand.default(arg0_1, [8, 256]);  arg0_1 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         add: "f32[8, 256][256, 1]cpu" = torch.ops.aten.add.Tensor(full, expand);  full = expand = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         mul: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(arg1_1, arg2_1)
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         cos: "f32[8, 256][256, 1]cpu" = torch.ops.aten.cos.default(mul);  mul = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         mul_1: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(add, cos);  add = cos = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         mul_2: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(mul_1, arg1_1);  arg1_1 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         mul_3: "f32[8, 256][256, 1]cpu" = torch.ops.aten.mul.Tensor(mul_1, arg2_1);  mul_1 = arg2_1 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         sum_1: "f32[1, 1][1, 1]cpu" = torch.ops.aten.sum.dim_IntList(mul_3, [0, 1], True);  mul_3 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         view: "f32[1][1]cpu" = torch.ops.aten.view.default(sum_1, [1]);  sum_1 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         new_empty_strided: "f32[1][1]cpu" = torch.ops.aten.new_empty_strided.default(view, [1], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         copy: "f32[1][1]cpu" = torch.ops.aten.copy.default(new_empty_strided, view);  new_empty_strided = view = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         view_1: "f32[2048][1]cpu" = torch.ops.aten.view.default(mul_2, [2048]);  mul_2 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         add_1: "f32[2048][1]cpu" = torch.ops.aten.add.Tensor(arg4_1, view_1);  view_1 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         copy_: "f32[2048][1]cpu" = torch.ops.aten.copy_.default(arg4_1, add_1);  arg4_1 = add_1 = None
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         return (copy, copy_)
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs]         
I0121 18:37:47.457000 3009140 site-packages/torch/_functorch/_aot_autograd/graph_capture.py:289] [!0/5/0] [__aot_graphs] 
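
If I read the backward graph right, it is just the chain rule for the toy model: with loss = sum(sin(scale * x)), the gradient for x is grad_out * cos(scale * x) * scale (the mul_1 / mul_2 nodes) and the gradient for scale is sum(grad_out * cos(scale * x) * x) (the mul_3 / sum_2 / view nodes). So the backward is emitted as its own complete graph rather than being fused with the forward.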

If I run the code with TORCH_LOGS="graph_breaks", I see this statement indicating that loss.backward() will always cause a graph break:

V0121 18:41:27.233000 3009873 site-packages/torch/_dynamo/symbolic_convert.py:611] [1/0] [__graph_breaks] Graph Break Reason: Unsupported Tensor.backward() call
V0121 18:41:27.233000 3009873 site-packages/torch/_dynamo/symbolic_convert.py:611] [1/0] [__graph_breaks]   Explanation: Dynamo currently does not support tracing `Tensor.backward()`.
V0121 18:41:27.233000 3009873 site-packages/torch/_dynamo/symbolic_convert.py:611] [1/0] [__graph_breaks]   Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.

FYI, when I write the train_step using torch.func, I do see an AOT graph with the forward and backward in the same graph.
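
For reference, the torch.func version looks roughly like this. It's a minimal sketch against the toy model above; torch.func.functional_call and torch.func.grad are the actual APIs, while compute_loss and func_train_step are just my names:

    import torch
    from torch.func import functional_call, grad

    # Express the loss as a pure function of the parameters so that
    # grad() can trace forward + backward inside one function.
    def compute_loss(params, data):
        out = functional_call(model, params, (data,))
        return out.sum()

    @torch.compile(fullgraph=True)
    def func_train_step(params, data):
        # grad() differentiates compute_loss w.r.t. its first argument,
        # so the backward ends up in the same traced graph as the forward.
        return grad(compute_loss)(params, data)

    params = dict(model.named_parameters())
    grads = func_train_step(params, data)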