Hello! I am having some trouble training a pix2pix model when using torch.compile().
I have looked into this but the suggested solution (compile the models and not the training function) does not solve the problem.
Here I report the code (a simplification) that generates the error:
import torch
import torch.nn as nn
class EncoderBlock(nn.Module):
def __init__(self, in_size, out_size, kernel_size=4, padding=1, stride=2, batch_norm=True,
leaky_relu_slope=0.2, use_instance_norm=False):
"""
Convolutional block
:param in_size: input depth
:param out_size: output depth
:param kernel_size: kernel size
:param padding: padding
:param stride: stride
:param batch_norm: whether to use batch normalization/instance normalization
:param leaky_relu_slope: slope of leaky ReLU
:param use_instance_norm: use instance normalization if batch size is 1
"""
super().__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_size, out_size, kernel_size=kernel_size, padding=padding, stride=stride,
bias=False),
nn.BatchNorm2d(out_size) if (batch_norm and not use_instance_norm) else None,
nn.InstanceNorm2d(out_size, affine=True) if (
batch_norm and use_instance_norm) else None,
nn.LeakyReLU(inplace=True, negative_slope=leaky_relu_slope))
# remove None from conv_block
self.conv_block = nn.Sequential(*[x for x in self.conv_block if x is not None])
# initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
nn.init.normal_(module.weight.data, 1.0, 0.2)
nn.init.constant_(module.bias.data, 0)
def forward(self, x):
y = self.conv_block(x)
return y
class DecoderBlock(nn.Module):
def __init__(self, in_size, out_size, kernel_size=4, padding=1, stride=2, batch_norm=True,
apply_dropout=True, dropout_p=0.5, use_instance_norm=False):
"""
Convolutional block
:param in_size: input depth
:param out_size: output depth
:param kernel_size: kernel size
:param padding: padding
:param stride: stride
:param batch_norm: whether to use batch normalization
:param apply_dropout: whether to apply dropout
:param dropout_p: dropout probability
:param use_instance_norm: use instance normalization if batch size is 1
"""
super().__init__()
self.decoder_block = nn.Sequential(
nn.ConvTranspose2d(in_size, out_size, kernel_size=kernel_size, padding=padding,
stride=stride, bias=False),
nn.BatchNorm2d(out_size) if (batch_norm and not use_instance_norm) else None,
nn.InstanceNorm2d(out_size, affine=True) if (
batch_norm and use_instance_norm) else None,
nn.Dropout2d(p=dropout_p) if apply_dropout else None,
nn.ReLU(inplace=True))
# remove None from decoder_block
self.decoder_block = nn.Sequential(
*[x for x in self.decoder_block if x is not None])
# initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.ConvTranspose2d):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
nn.init.normal_(module.weight.data, 1.0, 0.2)
nn.init.constant_(module.bias.data, 0)
def forward(self, x):
y = self.decoder_block(x)
return y
class Generator(nn.Module):
def __init__(self, in_channels=1, out_channels=3, filters=(
64, 128, 256, 512, 512, 512, 512, 512), use_instance_norm=False,
leaky_relu_slope=0.2):
"""
Generator network
:param in_channels: input channels, 1 for grayscale
:param out_channels: output channels, 3 for RGB
:param filters: number of filters in each layer
:param use_instance_norm: use instance normalization if batch size is 1
:param leaky_relu_slope: slope of leaky ReLU
"""
super().__init__()
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.last = nn.ConvTranspose2d(2 * filters[0], out_channels, kernel_size=4, padding=1,
stride=2)
# initialize weights for last layer using xavier uniform
nn.init.xavier_uniform_(self.last.weight)
nn.init.constant_(self.last.bias, 0)
self.final_activation = nn.Tanh()
# Encoder
for i, n_filters in enumerate(filters):
if i == 0:
self.encoder.append(EncoderBlock(in_channels, n_filters, batch_norm=False,
use_instance_norm=use_instance_norm,
leaky_relu_slope=leaky_relu_slope))
# cannot use instance norm for last layer of encoder because spatial dimensions are 1
elif i == len(filters) - 1 and use_instance_norm:
self.encoder.append(EncoderBlock(filters[i - 1], n_filters, batch_norm=False,
use_instance_norm=use_instance_norm,
leaky_relu_slope=leaky_relu_slope))
else:
self.encoder.append(EncoderBlock(filters[i - 1], n_filters,
use_instance_norm=use_instance_norm,
leaky_relu_slope=leaky_relu_slope))
# Decoder
for i, n_filters in enumerate(reversed(filters[:-1])):
if i == 0:
self.decoder.append(DecoderBlock(filters[-i - 1], n_filters, apply_dropout=False,
use_instance_norm=use_instance_norm))
elif 1 <= i <= 2:
self.decoder.append(
DecoderBlock(2 * filters[-i - 1], n_filters, apply_dropout=False,
use_instance_norm=use_instance_norm))
else:
self.decoder.append(DecoderBlock(2 * filters[-i - 1], n_filters,
use_instance_norm=use_instance_norm))
def forward(self, x):
# Create a list to store skip-connections
skip = []
# Encoder:
for encoder_step in self.encoder:
x = encoder_step(x)
skip.append(x)
skip = skip[:-1] # remove last element
# Decoder:
for skip_connection, decoder_step in zip(reversed(skip), self.decoder):
x = decoder_step(x)
x = torch.cat((x, skip_connection), dim=1)
# Last layer:
x = self.last(x)
x = self.final_activation(x)
return x
class Discriminator(nn.Module):
def __init__(self, leaky_relu_slope=0.2, in_channels=4, filters=(64, 128, 256, 512),
use_instance_norm=False, disable_norm=False, network_type='patch_gan'):
"""
Discriminator network
:param leaky_relu_slope: slope of leaky ReLU
:param in_channels: input channels, 4 for RGB + grayscale
:param filters: number of filters in each layer
:param use_instance_norm: use instance normalization if batch size is 1
:param disable_norm: disable normalization (for WGAN-GP)
:param network_type: type of discriminator network, 'patch_gan' or 'DCGAN'
"""
super(Discriminator, self).__init__()
self.leaky_relu_slope = leaky_relu_slope
self.in_channels = in_channels
self.filters = filters
assert network_type in ['patch_gan', 'DCGAN'], 'network_type must be either patch_gan or DCGAN'
self.network_type = network_type
self.model = nn.ModuleList()
for i, n_filters in enumerate(filters):
if i == 0:
self.model.append(EncoderBlock(in_channels, n_filters, kernel_size=4,
batch_norm=False,
leaky_relu_slope=leaky_relu_slope,
use_instance_norm=use_instance_norm))
elif 1 <= i <= 2:
self.model.append(EncoderBlock(filters[i - 1], n_filters, kernel_size=4,
batch_norm=(not disable_norm),
leaky_relu_slope=leaky_relu_slope,
use_instance_norm=use_instance_norm))
else:
self.model.append(EncoderBlock(filters[i - 1], n_filters, kernel_size=4,
batch_norm=(not disable_norm),
stride=1, leaky_relu_slope=leaky_relu_slope,
use_instance_norm=use_instance_norm))
if self.network_type == 'patch_gan':
last = nn.Conv2d(filters[-1], 1, kernel_size=4, stride=1, padding=1)
# initialize weights for last layer using xavier uniform
nn.init.xavier_uniform_(last.weight)
nn.init.constant_(last.bias, 0)
self.model.append(last)
elif self.network_type == 'DCGAN':
self.model.append(nn.AdaptiveAvgPool2d(1))
self.model.append(nn.Flatten())
self.model.append(nn.LazyLinear(1))
def forward(self, x):
for layer in self.model:
x = layer(x)
return x
# define generator
generator = Generator(use_instance_norm=False, in_channels=1, out_channels=2)
generator = torch.compile(generator)
generator.cuda()
# test forward pass
x = torch.randn(4, 1, 256, 256).cuda()
print(generator(x).shape)
# define discriminator
discriminator = Discriminator(use_instance_norm=False, in_channels=3, disable_norm=False, network_type='patch_gan')
discriminator = torch.compile(discriminator)
discriminator.cuda()
# test forward pass
x = torch.randn(4, 3, 256, 256).cuda()
print(discriminator(x).shape)
gen_opt = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.9))
disc_opt = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.9))
torch.autograd.set_detect_anomaly(True)
torch._dynamo.config.verbose=True
gray = torch.randn(4, 1, 256, 256).cuda() # L channel
color = torch.randn(4, 2, 256, 256).cuda() # ab channels
# train discriminator
disc_opt.zero_grad()
# train discriminator with fakes
fake_color = generator(gray)
fake_preds = discriminator(torch.cat((gray, fake_color.detach()), dim=1))
fake_targets = torch.zeros_like(fake_preds)
fake_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_preds, fake_targets)
fake_loss.backward()
# train discriminator with reals
real_preds = discriminator(torch.cat((gray, color), dim=1))
real_targets = torch.ones_like(real_preds)
real_loss = torch.nn.functional.binary_cross_entropy_with_logits(real_preds, real_targets)
real_loss.backward()
disc_opt.step()
# train generator
gen_opt.zero_grad()
fake_color = generator(gray)
fake_preds = discriminator(torch.cat((gray, fake_color), dim=1)) # error occurs here
fake_targets = torch.ones_like(fake_preds)
gen_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_preds, fake_targets)
gen_loss.backward()
gen_opt.step()
and the output when running:
torch.Size([4, 2, 256, 256])
torch.Size([4, 1, 30, 30])
/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/autograd/__init__.py:303: UserWarning: Error detected in ConvolutionBackward0. Traceback of forward call that caused the error:
File "/mnt/c/Users/loren/Dropbox/Università/Vision and Cognitive Services/project/error.py", line 235, in forward
x = layer(x)
(Triggered internally at /opt/conda/conda-bld/pytorch_1682343967769/work/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2822, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2515, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1715, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2104, in aot_dispatch_autograd
fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 714, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 443, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
(self.create_arg(fn(*args)),),
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 652, in flatten_fn
tree_out = root_fn(*tree_args)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 459, in wrapped
out = f(*tensors)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1158, in traced_joint
return functionalized_f_helper(primals, tangents)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1110, in functionalized_f_helper
f_outs = flat_fn_no_input_mutations(fn, f_primals, f_tangents, meta, keep_input_mutations)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1078, in flat_fn_no_input_mutations
outs = flat_fn_with_synthetic_bases_expanded(fn, primals, primals_after_cloning, maybe_tangents, meta, keep_input_mutations)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1050, in flat_fn_with_synthetic_bases_expanded
outs = forward_or_joint(fn, primals_before_cloning, primals, maybe_tangents, meta, keep_input_mutations)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1019, in forward_or_joint
backward_out = torch.autograd.grad(
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 269, in grad
return handle_torch_function(
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/overrides.py", line 1534, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_inductor/overrides.py", line 38, in __torch_function__
return func(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 303, in grad
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 487, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 512, in inner_torch_dispatch
out = proxy_call(self, func, args, kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 345, in proxy_call
out = func(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_ops.py", line 287, in __call__
return self._op(*args, **kwargs or {})
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1162, in dispatch
op_impl_out = op_impl(self, func, *args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 410, in local_scalar_dense
raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/mnt/c/Users/loren/Dropbox/Università/Vision and Cognitive Services/project/error.py", line 288, in <module>
fake_preds = discriminator(torch.cat((gray, fake_color), dim=1)) # error occurs here
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
self.output.compile_subgraph(
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised DataDependentOutputException: aten._local_scalar_dense.default
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True