Training code
manual_seed(args.seed)
torch.backends.cudnn.benchmark = True
with open(args.model_path+'/config.yaml') as f:
config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
config.training.num_steps = args.num_steps
trainset = MSSDatasets(config, args.data_root)
train_loader = DataLoader(
trainset,
batch_size=config.training.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=args.pin_memory
)
model = TFC_TDF_net(config)
model = torch.compile(model)
model.train()
device_ids = args.device_ids
if type(device_ids)==int:
device = torch.device(f'cuda:{device_ids}')
model = model.to(device)
else:
device = torch.device(f'cuda:{device_ids[0]}')
model = nn.DataParallel(model, device_ids=device_ids).to(device)
optimizer = Adam(model.parameters(), lr=config.training.lr)
print('Train Loop')
scaler = GradScaler()
for batch in tqdm(train_loader):
y = batch.to(device)
x = y.sum(1) # mixture
if config.training.target_instrument is not None:
i = config.training.instruments.index(config.training.target_instrument)
y = y[:,i]
with torch.cuda.amp.autocast():
y_ = model(x)
loss = nn.MSELoss()(y_, y)
scaler.scale(loss).backward()
if config.training.grad_clip:
nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
state_dict = model.state_dict() if type(device_ids)==int else model.module.state_dict()
torch.save(state_dict, args.model_path+'/ckpt')
if name == “main”:
train()`
Model code
> class STFT:
> def __init__(self, config):
> self.n_fft = config.n_fft
> self.hop_length = config.hop_length
> self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
> self.dim_f = config.dim_f
>
> def __call__(self, x):
> window = self.window.to(x.device)
> batch_dims = x.shape[:-2]
> c, t = x.shape[-2:]
> x = x.reshape([-1, t])
> x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=False)
> x = x.permute([0,3,1,2])
> x = x.reshape([*batch_dims,c,2,-1,x.shape[-1]]).reshape([*batch_dims,c*2,-1,x.shape[-1]])
> return x[...,:self.dim_f,:]
>
> def inverse(self, x):
> window = self.window.to(x.device)
> batch_dims = x.shape[:-3]
> c,f,t = x.shape[-3:]
> n = self.n_fft//2+1
> f_pad = torch.zeros([*batch_dims,c,n-f,t]).to(x.device)
> x = torch.cat([x, f_pad], -2)
> x = x.reshape([*batch_dims,c//2,2,n,t]).reshape([-1,2,n,t])
> x = x.permute([0,2,3,1])
> x = x[...,0] + x[...,1] * 1.j
> x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
> x = x.reshape([*batch_dims,2,-1])
> return x
>
>
> def get_norm(norm_type):
> def norm(c, norm_type):
> if norm_type=='BatchNorm':
> return nn.BatchNorm2d(c)
> elif norm_type=='InstanceNorm':
> return nn.InstanceNorm2d(c, affine=True)
> elif 'GroupNorm' in norm_type:
> g = int(norm_type.replace('GroupNorm', ''))
> return nn.GroupNorm(num_groups=g, num_channels=c)
> else:
> return nn.Identity()
> return partial(norm, norm_type=norm_type)
>
>
> def get_act(act_type):
> if act_type=='gelu':
> return nn.GELU()
> elif act_type=='relu':
> return nn.ReLU()
> elif act_type[:3]=='elu':
> alpha = float(act_type.replace('elu', ''))
> return nn.ELU(alpha)
> else:
> raise Exception
>
>
> class Upscale(nn.Module):
> def __init__(self, in_c, out_c, scale, norm, act):
> super().__init__()
> self.conv = nn.Sequential(
> norm(in_c),
> act,
> nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
> )
>
> def forward(self, x):
> return self.conv(x)
>
>
> class Downscale(nn.Module):
> def __init__(self, in_c, out_c, scale, norm, act):
> super().__init__()
> self.conv = nn.Sequential(
> norm(in_c),
> act,
> nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
> )
>
> def forward(self, x):
> return self.conv(x)
>
>
> class TFC_TDF(nn.Module):
> def __init__(self, in_c, c, l, f, bn, norm, act):
> super().__init__()
>
> self.blocks = nn.ModuleList()
> for i in range(l):
> block = nn.Module()
>
> block.tfc1 = nn.Sequential(
> norm(in_c),
> act,
> nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
> )
> block.tdf = nn.Sequential(
> norm(c),
> act,
> nn.Linear(f, f//bn, bias=False),
> norm(c),
> act,
> nn.Linear(f//bn, f, bias=False),
> )
> block.tfc2 = nn.Sequential(
> norm(c),
> act,
> nn.Conv2d(c, c, 3, 1, 1, bias=False),
> )
> block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
>
> self.blocks.append(block)
> in_c = c
>
> def forward(self, x):
> for block in self.blocks:
> s = block.shortcut(x)
> x = block.tfc1(x)
> x = x + block.tdf(x)
> x = block.tfc2(x)
> x = x + s
> return x
>
>
> class TFC_TDF_net(nn.Module):
> def __init__(self, config):
> super().__init__()
> self.config = config
>
> norm = get_norm(norm_type=config.model.norm)
> act = get_act(act_type=config.model.act)
>
> self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
> self.num_subbands = config.model.num_subbands
>
> dim_c = self.num_subbands * config.audio.num_channels * 2
> n = config.model.num_scales
> scale = config.model.scale
> l = config.model.num_blocks_per_scale
> c = config.model.num_channels
> g = config.model.growth
> bn = config.model.bottleneck_factor
> f = config.audio.dim_f // self.num_subbands
>
> self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
>
> self.encoder_blocks = nn.ModuleList()
> for i in range(n):
> block = nn.Module()
> block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
> block.downscale = Downscale(c, c+g, scale, norm, act)
> f = f//scale[1]
> c += g
> self.encoder_blocks.append(block)
>
> self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
>
> self.decoder_blocks = nn.ModuleList()
> for i in range(n):
> block = nn.Module()
> block.upscale = Upscale(c, c-g, scale, norm, act)
> f = f*scale[1]
> c -= g
> block.tfc_tdf = TFC_TDF(2*c, c, l, f, bn, norm, act)
> self.decoder_blocks.append(block)
>
> self.final_conv = nn.Sequential(
> nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
> act,
> nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
> )
>
> self.stft = STFT(config.audio)
>
> def cac2cws(self, x):
> k = self.num_subbands
> b,c,f,t = x.shape
> x = x.reshape(b,c,k,f//k,t)
> x = x.reshape(b,c*k,f//k,t)
> return x
>
> def cws2cac(self, x):
> k = self.num_subbands
> b,c,f,t = x.shape
> x = x.reshape(b,c//k,k,f,t)
> x = x.reshape(b,c//k,f*k,t)
> return x
>
> def forward(self, x):
>
> x = self.stft(x)
>
> mix = x = self.cac2cws(x)
>
> first_conv_out = x = self.first_conv(x)
>
> x = x.transpose(-1,-2)
>
> encoder_outputs = []
> for block in self.encoder_blocks:
> x = block.tfc_tdf(x)
> encoder_outputs.append(x)
> x = block.downscale(x)
>
> x = self.bottleneck_block(x)
>
> for block in self.decoder_blocks:
> x = block.upscale(x)
> x = torch.cat([x, encoder_outputs.pop()], 1)
> x = block.tfc_tdf(x)
>
> x = x.transpose(-1,-2)
>
> x = x * first_conv_out # reduce artifacts
>
> x = self.final_conv(torch.cat([mix, x], 1))
>
> x = self.cws2cac(x)
>
> if self.num_target_instruments > 1:
> b,c,f,t = x.shape
> x = x.reshape(b,self.num_target_instruments,-1,f,t)
>
> x = self.stft.inverse(x)
>
> return x
Error logs
0% 0/1000000 [00:00<?, ?it/s][2023-07-16 14:24:31,474] torch._inductor.utils: [WARNING] DeviceCopy in input program
0% 0/1000000 [01:07<?, ?it/s]
Traceback (most recent call last):
File "/content/sdx23/my_submission/src/train.py", line 120, in <module>
train()
File "/content/sdx23/my_submission/src/train.py", line 91, in train
out = model(x)
^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
return self.module(*inputs[0], **module_kwargs[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/content/sdx23/my_submission/src/tfc_tdf_v3.py", line 196, in forward
def forward(self, x):
File "/content/sdx23/my_submission/src/tfc_tdf_v3.py", line 198, in <resume in forward>
x = self.stft(x)
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 447, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 128, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 364, in _convert_frame_assert
return _compile(
^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 434, in _compile
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1002, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 419, in transform
tracer.run()
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2068, in run
super().run()
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 727, in run
and self.step()
^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 687, in step
getattr(self, inst.opname)(inst)
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 441, in wrapper
self.output.compile_subgraph(self, reason=reason)
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 815, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/usr/local/envs/mdx-net/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 915, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 971, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 967, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/__init__.py", line 1548, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1045, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 3750, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 3289, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2098, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2278, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2686, in aot_dispatch_autograd
fx_g = aot_dispatch_autograd_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2663, in aot_dispatch_autograd_graph
fx_g = create_functionalized_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1399, in create_functionalized_graph
fx_g = make_fx(helper, decomposition_table=aot_config.decompositions)(*args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 809, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 468, in dispatch_trace
graph = tracer.trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 684, in flatten_fn
tree_out = root_fn(*tree_args)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 485, in wrapped
out = f(*tensors)
^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1388, in joint_helper
return functionalized_f_helper(primals, tangents)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1341, in functionalized_f_helper
f_outs = fn(*f_args)
^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1312, in inner_fn_with_anomaly
return inner_fn(*args)
^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1295, in inner_fn
backward_out = torch.autograd.grad(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/autograd/__init__.py", line 319, in grad
result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 555, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 580, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 361, in proxy_call
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_ops.py", line 437, in __call__
return self._op(*args, **kwargs or {})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'.
Position: 0
Value: 1j
Declaration: aten::_conj(Tensor(a) self) -> Tensor(a)
Cast error details: Unable to cast 1j to Tensor
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Versions
# packages in environment at /usr/local/envs/mdx-net:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_gnu conda-forge
absl-py 1.4.0 pypi_0 pypi
antlr4-python3-runtime 4.9.3 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
bzip2 1.0.8 h7f98852_4 conda-forge
ca-certificates 2023.5.7 hbcca054_0 conda-forge
certifi 2023.5.7 pypi_0 pypi
cffi 1.15.1 pypi_0 pypi
charset-normalizer 3.2.0 pypi_0 pypi
click 8.1.5 pypi_0 pypi
cloudpickle 2.2.1 pypi_0 pypi
cmake 3.26.4 pypi_0 pypi
contextlib2 21.6.0 pypi_0 pypi
cython 0.29.36 pypi_0 pypi
demucs 4.0.0 pypi_0 pypi
diffq 0.2.4 pypi_0 pypi
docker-pycreds 0.4.0 pypi_0 pypi
dora-search 0.1.12 pypi_0 pypi
einops 0.6.1 pypi_0 pypi
ffmpeg-python 0.2.0 pypi_0 pypi
filelock 3.12.2 pypi_0 pypi
fsspec 2023.4.0 pypi_0 pypi
future 0.18.3 pypi_0 pypi
gitdb 4.0.10 pypi_0 pypi
gitpython 3.1.32 pypi_0 pypi
idna 3.4 pypi_0 pypi
jinja2 3.1.2 pypi_0 pypi
jsonschema 4.18.3 pypi_0 pypi
jsonschema-specifications 2023.6.1 pypi_0 pypi
julius 0.2.7 pypi_0 pypi
lameenc 1.5.1 pypi_0 pypi
ld_impl_linux-64 2.40 h41732ed_0 conda-forge
libexpat 2.5.0 hcb278e6_1 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-ng 13.1.0 he5830b7_0 conda-forge
libgomp 13.1.0 he5830b7_0 conda-forge
libnsl 2.0.0 h7f98852_0 conda-forge
libsqlite 3.42.0 h2797004_0 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
lit 16.0.6 pypi_0 pypi
markupsafe 2.1.3 pypi_0 pypi
mir-eval 0.7 pypi_0 pypi
ml-collections 0.1.1 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
musdb 0.4.0 pypi_0 pypi
museval 0.4.1 pypi_0 pypi
ncurses 6.4 hcb278e6_0 conda-forge
networkx 3.1 pypi_0 pypi
numpy 1.25.1 pypi_0 pypi
nvidia-cublas-cu11 11.10.3.66 pypi_0 pypi
nvidia-cuda-cupti-cu11 11.7.101 pypi_0 pypi
nvidia-cuda-nvrtc-cu11 11.7.99 pypi_0 pypi
nvidia-cuda-runtime-cu11 11.7.99 pypi_0 pypi
nvidia-cudnn-cu11 8.5.0.96 pypi_0 pypi
nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi
nvidia-curand-cu11 10.2.10.91 pypi_0 pypi
nvidia-cusolver-cu11 11.4.0.1 pypi_0 pypi
nvidia-cusparse-cu11 11.7.4.91 pypi_0 pypi
nvidia-nccl-cu11 2.14.3 pypi_0 pypi
nvidia-nvtx-cu11 11.7.91 pypi_0 pypi
omegaconf 2.3.0 pypi_0 pypi
openssl 3.1.1 hd590300_1 conda-forge
openunmix 1.2.1 pypi_0 pypi
pandas 2.0.3 pypi_0 pypi
pathtools 0.1.2 pypi_0 pypi
pip 23.2 pyhd8ed1ab_0 conda-forge
promise 2.3 pypi_0 pypi
protobuf 3.20.3 pypi_0 pypi
psutil 5.9.5 pypi_0 pypi
pyaml 23.7.0 pypi_0 pypi
pycparser 2.21 pypi_0 pypi
python 3.11.4 hab00c5b_0_cpython conda-forge
python-dateutil 2.8.2 pypi_0 pypi
pytorch-triton 2.1.0+3c400e7818 pypi_0 pypi
pytz 2023.3 pypi_0 pypi
pyyaml 6.0 pypi_0 pypi
readline 8.2 h8228510_1 conda-forge
referencing 0.29.1 pypi_0 pypi
requests 2.31.0 pypi_0 pypi
retrying 1.3.4 pypi_0 pypi
rpds-py 0.8.10 pypi_0 pypi
scipy 1.11.1 pypi_0 pypi
sentry-sdk 1.28.1 pypi_0 pypi
setproctitle 1.3.2 pypi_0 pypi
setuptools 68.0.0 pyhd8ed1ab_0 conda-forge
shortuuid 1.0.11 pypi_0 pypi
simplejson 3.19.1 pypi_0 pypi
six 1.16.0 pypi_0 pypi
smmap 5.0.0 pypi_0 pypi
soundfile 0.12.1 pypi_0 pypi
stempeg 0.2.3 pypi_0 pypi
submitit 1.4.5 pypi_0 pypi
sympy 1.12 pypi_0 pypi
tk 8.6.12 h27826a3_0 conda-forge
torch 2.1.0.dev20230716+cu118 pypi_0 pypi
torchaudio 2.0.2 pypi_0 pypi
tqdm 4.65.0 pypi_0 pypi
treetable 0.2.5 pypi_0 pypi
triton 2.0.0 pypi_0 pypi
typing-extensions 4.7.1 pypi_0 pypi
tzdata 2023.3 pypi_0 pypi
urllib3 2.0.3 pypi_0 pypi
wandb 0.13.2 pypi_0 pypi
wheel 0.40.0 pyhd8ed1ab_1 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge