Torch.compile Error: RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'.

Training code

manual_seed(args.seed)
torch.backends.cudnn.benchmark = True

with open(args.model_path+'/config.yaml') as f:
    config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
config.training.num_steps = args.num_steps

trainset = MSSDatasets(config, args.data_root)

train_loader = DataLoader(
    trainset, 
    batch_size=config.training.batch_size, 
    shuffle=True, 
    num_workers=args.num_workers, 
    pin_memory=args.pin_memory
)

model = TFC_TDF_net(config)
model = torch.compile(model)
model.train()

device_ids = args.device_ids
if type(device_ids)==int:
    device = torch.device(f'cuda:{device_ids}')
    model = model.to(device)
else:
    device = torch.device(f'cuda:{device_ids[0]}')
    model = nn.DataParallel(model, device_ids=device_ids).to(device)

optimizer = Adam(model.parameters(), lr=config.training.lr)

print('Train Loop')
scaler = GradScaler()    
for batch in tqdm(train_loader):   

    y = batch.to(device)
    x = y.sum(1)  # mixture   
    if config.training.target_instrument is not None:
        i = config.training.instruments.index(config.training.target_instrument)
        y = y[:,i]
    with torch.cuda.amp.autocast():        
        y_ = model(x)   
        loss = nn.MSELoss()(y_, y) 

    scaler.scale(loss).backward()
    if config.training.grad_clip:
        nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)  
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)


state_dict = model.state_dict() if type(device_ids)==int else model.module.state_dict()

torch.save(state_dict, args.model_path+'/ckpt')

if name == “main”:
train()`

Model code

> class STFT:
>     def __init__(self, config):
>         self.n_fft = config.n_fft
>         self.hop_length = config.hop_length
>         self.window = torch.hann_window(window_length=self.n_fft, periodic=True)        
>         self.dim_f = config.dim_f
>     
>     def __call__(self, x):
>         window = self.window.to(x.device)
>         batch_dims = x.shape[:-2]
>         c, t = x.shape[-2:]
>         x = x.reshape([-1, t])
>         x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=False)
>         x = x.permute([0,3,1,2])
>         x = x.reshape([*batch_dims,c,2,-1,x.shape[-1]]).reshape([*batch_dims,c*2,-1,x.shape[-1]])
>         return x[...,:self.dim_f,:]
> 
>     def inverse(self, x):
>         window = self.window.to(x.device)
>         batch_dims = x.shape[:-3]
>         c,f,t = x.shape[-3:]
>         n = self.n_fft//2+1
>         f_pad = torch.zeros([*batch_dims,c,n-f,t]).to(x.device)
>         x = torch.cat([x, f_pad], -2)
>         x = x.reshape([*batch_dims,c//2,2,n,t]).reshape([-1,2,n,t])
>         x = x.permute([0,2,3,1])
>         x = x[...,0] + x[...,1] * 1.j
>         x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
>         x = x.reshape([*batch_dims,2,-1])
>         return x
> 
>     
> def get_norm(norm_type):
>     def norm(c, norm_type):   
>         if norm_type=='BatchNorm':
>             return nn.BatchNorm2d(c)
>         elif norm_type=='InstanceNorm':
>             return nn.InstanceNorm2d(c, affine=True)
>         elif 'GroupNorm' in norm_type:
>             g = int(norm_type.replace('GroupNorm', ''))
>             return nn.GroupNorm(num_groups=g, num_channels=c)
>         else:
>             return nn.Identity()
>     return partial(norm, norm_type=norm_type)
> 
> 
> def get_act(act_type):
>     if act_type=='gelu':
>         return nn.GELU()
>     elif act_type=='relu':
>         return nn.ReLU()
>     elif act_type[:3]=='elu':
>         alpha = float(act_type.replace('elu', ''))
>         return nn.ELU(alpha)
>     else:
>         raise Exception
> 
>         
> class Upscale(nn.Module):
>     def __init__(self, in_c, out_c, scale, norm, act):
>         super().__init__()
>         self.conv = nn.Sequential(
>             norm(in_c),
>             act,  
>             nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
>         )
>                                   
>     def forward(self, x):
>         return self.conv(x)
> 
> 
> class Downscale(nn.Module):
>     def __init__(self, in_c, out_c, scale, norm, act):
>         super().__init__()
>         self.conv = nn.Sequential(
>             norm(in_c),
>             act,   
>             nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
>         )
>                                   
>     def forward(self, x):
>         return self.conv(x)
> 
> 
> class TFC_TDF(nn.Module):
>     def __init__(self, in_c, c, l, f, bn, norm, act):        
>         super().__init__()
> 
>         self.blocks = nn.ModuleList()
>         for i in range(l): 
>             block = nn.Module()
>             
>             block.tfc1 = nn.Sequential(
>                 norm(in_c),
>                 act,
>                 nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
>             )
>             block.tdf = nn.Sequential(
>                 norm(c),
>                 act,
>                 nn.Linear(f, f//bn, bias=False),
>                 norm(c),
>                 act,
>                 nn.Linear(f//bn, f, bias=False),
>             )
>             block.tfc2 = nn.Sequential(
>                 norm(c),
>                 act,
>                 nn.Conv2d(c, c, 3, 1, 1, bias=False),
>             )
>             block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
>             
>             self.blocks.append(block)
>             in_c = c
>               
>     def forward(self, x):
>         for block in self.blocks:
>             s = block.shortcut(x)
>             x = block.tfc1(x)
>             x = x + block.tdf(x)
>             x = block.tfc2(x)
>             x = x + s
>         return x
> 
> 
> class TFC_TDF_net(nn.Module):
>     def __init__(self, config):
>         super().__init__()
>         self.config = config
>         
>         norm = get_norm(norm_type=config.model.norm)
>         act = get_act(act_type=config.model.act)
>         
>         self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
>         self.num_subbands = config.model.num_subbands
>         
>         dim_c = self.num_subbands * config.audio.num_channels * 2         
>         n = config.model.num_scales
>         scale = config.model.scale
>         l = config.model.num_blocks_per_scale 
>         c = config.model.num_channels
>         g = config.model.growth
>         bn = config.model.bottleneck_factor               
>         f = config.audio.dim_f // self.num_subbands
>         
>         self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
>  
>         self.encoder_blocks = nn.ModuleList()
>         for i in range(n):
>             block = nn.Module()
>             block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
>             block.downscale = Downscale(c, c+g, scale, norm, act) 
>             f = f//scale[1]
>             c += g
>             self.encoder_blocks.append(block)                
>                
>         self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
>         
>         self.decoder_blocks = nn.ModuleList()
>         for i in range(n):                
>             block = nn.Module()
>             block.upscale = Upscale(c, c-g, scale, norm, act)
>             f = f*scale[1]
>             c -= g  
>             block.tfc_tdf = TFC_TDF(2*c, c, l, f, bn, norm, act)
>             self.decoder_blocks.append(block) 
>               
>         self.final_conv = nn.Sequential(
>             nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
>             act,
>             nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
>         )
>         
>         self.stft = STFT(config.audio)
>     
>     def cac2cws(self, x):
>         k = self.num_subbands
>         b,c,f,t = x.shape
>         x = x.reshape(b,c,k,f//k,t)
>         x = x.reshape(b,c*k,f//k,t)
>         return x
>     
>     def cws2cac(self, x):
>         k = self.num_subbands
>         b,c,f,t = x.shape
>         x = x.reshape(b,c//k,k,f,t)
>         x = x.reshape(b,c//k,f*k,t)
>         return x
>     
>     def forward(self, x):
>         
>         x = self.stft(x)
>         
>         mix = x = self.cac2cws(x)
>         
>         first_conv_out = x = self.first_conv(x)
> 
>         x = x.transpose(-1,-2)
>         
>         encoder_outputs = []
>         for block in self.encoder_blocks:  
>             x = block.tfc_tdf(x) 
>             encoder_outputs.append(x)
>             x = block.downscale(x)              
>             
>         x = self.bottleneck_block(x)
>         
>         for block in self.decoder_blocks:            
>             x = block.upscale(x)
>             x = torch.cat([x, encoder_outputs.pop()], 1)
>             x = block.tfc_tdf(x) 
>             
>         x = x.transpose(-1,-2)
>         
>         x = x * first_conv_out  # reduce artifacts
>         
>         x = self.final_conv(torch.cat([mix, x], 1))
>         
>         x = self.cws2cac(x)
>         
>         if self.num_target_instruments > 1:
>             b,c,f,t = x.shape
>             x = x.reshape(b,self.num_target_instruments,-1,f,t)
>         
>         x = self.stft.inverse(x)
>         
>         return x

Error logs

  0% 0/1000000 [00:00<?, ?it/s][2023-07-16 14:24:31,474] torch._inductor.utils: [WARNING] DeviceCopy in input program
  0% 0/1000000 [01:07<?, ?it/s]
Traceback (most recent call last):
  File "/content/sdx23/my_submission/src/train.py", line 120, in <module>
    train()
  File "/content/sdx23/my_submission/src/train.py", line 91, in train
    out = model(x)
       ^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/sdx23/my_submission/src/tfc_tdf_v3.py", line 196, in forward
    def forward(self, x):
  File "/content/sdx23/my_submission/src/tfc_tdf_v3.py", line 198, in <resume in forward>
    x = self.stft(x)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 447, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 128, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 364, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 434, in _compile
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1002, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 419, in transform
    tracer.run()
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2068, in run
    super().run()
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 727, in run
    and self.step()
        ^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 687, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 441, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 815, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/local/envs/mdx-net/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 915, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 971, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 967, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/__init__.py", line 1548, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1045, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 3750, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 3289, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2098, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2278, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2686, in aot_dispatch_autograd
    fx_g = aot_dispatch_autograd_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2663, in aot_dispatch_autograd_graph
    fx_g = create_functionalized_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1399, in create_functionalized_graph
    fx_g = make_fx(helper, decomposition_table=aot_config.decompositions)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 809, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 468, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 684, in flatten_fn
    tree_out = root_fn(*tree_args)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 485, in wrapped
    out = f(*tensors)
          ^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1388, in joint_helper
    return functionalized_f_helper(primals, tangents)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1341, in functionalized_f_helper
    f_outs = fn(*f_args)
             ^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1312, in inner_fn_with_anomaly
    return inner_fn(*args)
           ^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1295, in inner_fn
    backward_out = torch.autograd.grad(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/autograd/__init__.py", line 319, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 555, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 580, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 361, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_ops.py", line 437, in __call__
    return self._op(*args, **kwargs or {})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'.
Position: 0
Value: 1j
Declaration: aten::_conj(Tensor(a) self) -> Tensor(a)
Cast error details: Unable to cast 1j to Tensor


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

# packages in environment at /usr/local/envs/mdx-net:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
absl-py                   1.4.0                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2023.5.7             hbcca054_0    conda-forge
certifi                   2023.5.7                 pypi_0    pypi
cffi                      1.15.1                   pypi_0    pypi
charset-normalizer        3.2.0                    pypi_0    pypi
click                     8.1.5                    pypi_0    pypi
cloudpickle               2.2.1                    pypi_0    pypi
cmake                     3.26.4                   pypi_0    pypi
contextlib2               21.6.0                   pypi_0    pypi
cython                    0.29.36                  pypi_0    pypi
demucs                    4.0.0                    pypi_0    pypi
diffq                     0.2.4                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
dora-search               0.1.12                   pypi_0    pypi
einops                    0.6.1                    pypi_0    pypi
ffmpeg-python             0.2.0                    pypi_0    pypi
filelock                  3.12.2                   pypi_0    pypi
fsspec                    2023.4.0                 pypi_0    pypi
future                    0.18.3                   pypi_0    pypi
gitdb                     4.0.10                   pypi_0    pypi
gitpython                 3.1.32                   pypi_0    pypi
idna                      3.4                      pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
jsonschema                4.18.3                   pypi_0    pypi
jsonschema-specifications 2023.6.1                 pypi_0    pypi
julius                    0.2.7                    pypi_0    pypi
lameenc                   1.5.1                    pypi_0    pypi
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libexpat                  2.5.0                hcb278e6_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.1.0               he5830b7_0    conda-forge
libgomp                   13.1.0               he5830b7_0    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libsqlite                 3.42.0               h2797004_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
lit                       16.0.6                   pypi_0    pypi
markupsafe                2.1.3                    pypi_0    pypi
mir-eval                  0.7                      pypi_0    pypi
ml-collections            0.1.1                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
musdb                     0.4.0                    pypi_0    pypi
museval                   0.4.1                    pypi_0    pypi
ncurses                   6.4                  hcb278e6_0    conda-forge
networkx                  3.1                      pypi_0    pypi
numpy                     1.25.1                   pypi_0    pypi
nvidia-cublas-cu11        11.10.3.66               pypi_0    pypi
nvidia-cuda-cupti-cu11    11.7.101                 pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.7.99                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
nvidia-cudnn-cu11         8.5.0.96                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.2.10.91               pypi_0    pypi
nvidia-cusolver-cu11      11.4.0.1                 pypi_0    pypi
nvidia-cusparse-cu11      11.7.4.91                pypi_0    pypi
nvidia-nccl-cu11          2.14.3                   pypi_0    pypi
nvidia-nvtx-cu11          11.7.91                  pypi_0    pypi
omegaconf                 2.3.0                    pypi_0    pypi
openssl                   3.1.1                hd590300_1    conda-forge
openunmix                 1.2.1                    pypi_0    pypi
pandas                    2.0.3                    pypi_0    pypi
pathtools                 0.1.2                    pypi_0    pypi
pip                       23.2               pyhd8ed1ab_0    conda-forge
promise                   2.3                      pypi_0    pypi
protobuf                  3.20.3                   pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
pyaml                     23.7.0                   pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
python                    3.11.4          hab00c5b_0_cpython    conda-forge
python-dateutil           2.8.2                    pypi_0    pypi
pytorch-triton            2.1.0+3c400e7818          pypi_0    pypi
pytz                      2023.3                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
readline                  8.2                  h8228510_1    conda-forge
referencing               0.29.1                   pypi_0    pypi
requests                  2.31.0                   pypi_0    pypi
retrying                  1.3.4                    pypi_0    pypi
rpds-py                   0.8.10                   pypi_0    pypi
scipy                     1.11.1                   pypi_0    pypi
sentry-sdk                1.28.1                   pypi_0    pypi
setproctitle              1.3.2                    pypi_0    pypi
setuptools                68.0.0             pyhd8ed1ab_0    conda-forge
shortuuid                 1.0.11                   pypi_0    pypi
simplejson                3.19.1                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
smmap                     5.0.0                    pypi_0    pypi
soundfile                 0.12.1                   pypi_0    pypi
stempeg                   0.2.3                    pypi_0    pypi
submitit                  1.4.5                    pypi_0    pypi
sympy                     1.12                     pypi_0    pypi
tk                        8.6.12               h27826a3_0    conda-forge
torch                     2.1.0.dev20230716+cu118          pypi_0    pypi
torchaudio                2.0.2                    pypi_0    pypi
tqdm                      4.65.0                   pypi_0    pypi
treetable                 0.2.5                    pypi_0    pypi
triton                    2.0.0                    pypi_0    pypi
typing-extensions         4.7.1                    pypi_0    pypi
tzdata                    2023.3                   pypi_0    pypi
urllib3                   2.0.3                    pypi_0    pypi
wandb                     0.13.2                   pypi_0    pypi
wheel                     0.40.0             pyhd8ed1ab_1    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge

issue was x-posted here Torch.compile Error: RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'. · Issue #105290 · pytorch/pytorch · GitHub

1 Like

Thank you, do you have any suggestions for a fix @marksaroufim

I hope its ok to tag you @ptrblck since there hasnt been activity on this thread or my github issue post, maybe you have any suggestions?

I believe devs are still looking into the issue and I’m not familiar enough with the complex support in torch.compile, unfortunately, so won’t be a huge help here.

You can follow the discussion here Compiling complex-valued functions fails · Issue #98161 · pytorch/pytorch · GitHub

I don’t think complex support is super high on the roadmap but feel free to disagree there with your use case. In the meantime it might be faster to make progrerss by rewriting your model to not use complex numbers and instead use 2 reals

EDIT: In hindsight my advice here was not useful, i tried doing this in the linked github thread and my code is insanely slow