Pix2pix model training error with torch.compile()

Hello! I am having some trouble training a pix2pix model when using torch.compile().
I have looked into this but the suggested solution (compile the models and not the training function) does not solve the problem.
Here I report the code (a simplification) that generates the error:

import torch
import torch.nn as nn

class EncoderBlock(nn.Module):

    def __init__(self, in_size, out_size, kernel_size=4, padding=1, stride=2, batch_norm=True,
                 leaky_relu_slope=0.2, use_instance_norm=False):
        """
        Convolutional block
        :param in_size: input depth
        :param out_size: output depth
        :param kernel_size: kernel size
        :param padding: padding
        :param stride: stride
        :param batch_norm: whether to use batch normalization/instance normalization
        :param leaky_relu_slope: slope of leaky ReLU
        :param use_instance_norm: use instance normalization if batch size is 1
        """

        super().__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=kernel_size, padding=padding, stride=stride,
                      bias=False),
            nn.BatchNorm2d(out_size) if (batch_norm and not use_instance_norm) else None,
            nn.InstanceNorm2d(out_size, affine=True) if (
                    batch_norm and use_instance_norm) else None,
            nn.LeakyReLU(inplace=True, negative_slope=leaky_relu_slope))

        # remove None from conv_block
        self.conv_block = nn.Sequential(*[x for x in self.conv_block if x is not None])

        # initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            nn.init.normal_(module.weight.data, 1.0, 0.2)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, x):
        y = self.conv_block(x)

        return y


class DecoderBlock(nn.Module):

    def __init__(self, in_size, out_size, kernel_size=4, padding=1, stride=2, batch_norm=True,
                 apply_dropout=True, dropout_p=0.5, use_instance_norm=False):
        """
        Convolutional block
        :param in_size: input depth
        :param out_size: output depth
        :param kernel_size: kernel size
        :param padding: padding
        :param stride: stride
        :param batch_norm: whether to use batch normalization
        :param apply_dropout: whether to apply dropout
        :param dropout_p: dropout probability
        :param use_instance_norm: use instance normalization if batch size is 1
        """

        super().__init__()


        self.decoder_block = nn.Sequential(
            nn.ConvTranspose2d(in_size, out_size, kernel_size=kernel_size, padding=padding,
                               stride=stride, bias=False),
            nn.BatchNorm2d(out_size) if (batch_norm and not use_instance_norm) else None,
            nn.InstanceNorm2d(out_size, affine=True) if (
                    batch_norm and use_instance_norm) else None,
            nn.Dropout2d(p=dropout_p) if apply_dropout else None,
            nn.ReLU(inplace=True))

        # remove None from decoder_block
        self.decoder_block = nn.Sequential(
            *[x for x in self.decoder_block if x is not None])


        # initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.ConvTranspose2d):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            nn.init.normal_(module.weight.data, 1.0, 0.2)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, x):
        y = self.decoder_block(x)

        return y


class Generator(nn.Module):

    def __init__(self, in_channels=1, out_channels=3, filters=(
            64, 128, 256, 512, 512, 512, 512, 512), use_instance_norm=False,
                 leaky_relu_slope=0.2):
        """
        Generator network
        :param in_channels: input channels, 1 for grayscale
        :param out_channels: output channels, 3 for RGB
        :param filters: number of filters in each layer
        :param use_instance_norm: use instance normalization if batch size is 1
        :param leaky_relu_slope: slope of leaky ReLU
        """
        super().__init__()

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.last = nn.ConvTranspose2d(2 * filters[0], out_channels, kernel_size=4, padding=1,
                                       stride=2)
        # initialize weights for last layer using xavier uniform
        nn.init.xavier_uniform_(self.last.weight)
        nn.init.constant_(self.last.bias, 0)

        self.final_activation = nn.Tanh()

        # Encoder
        for i, n_filters in enumerate(filters):
            if i == 0:
                self.encoder.append(EncoderBlock(in_channels, n_filters, batch_norm=False,
                                                 use_instance_norm=use_instance_norm,
                                                 leaky_relu_slope=leaky_relu_slope))
            # cannot use instance norm for last layer of encoder because spatial dimensions are 1
            elif i == len(filters) - 1 and use_instance_norm:
                self.encoder.append(EncoderBlock(filters[i - 1], n_filters, batch_norm=False,
                                                 use_instance_norm=use_instance_norm,
                                                 leaky_relu_slope=leaky_relu_slope))
            else:
                self.encoder.append(EncoderBlock(filters[i - 1], n_filters,
                                                 use_instance_norm=use_instance_norm,
                                                 leaky_relu_slope=leaky_relu_slope))

        # Decoder
        for i, n_filters in enumerate(reversed(filters[:-1])):

            if i == 0:
                self.decoder.append(DecoderBlock(filters[-i - 1], n_filters, apply_dropout=False,
                                                 use_instance_norm=use_instance_norm))
            elif 1 <= i <= 2:
                self.decoder.append(
                    DecoderBlock(2 * filters[-i - 1], n_filters, apply_dropout=False,
                                 use_instance_norm=use_instance_norm))
            else:
                self.decoder.append(DecoderBlock(2 * filters[-i - 1], n_filters,
                                                 use_instance_norm=use_instance_norm))

    def forward(self, x):

        # Create a list to store skip-connections
        skip = []

        # Encoder:
        for encoder_step in self.encoder:
            x = encoder_step(x)
            skip.append(x)

        skip = skip[:-1]  # remove last element

        # Decoder:
        for skip_connection, decoder_step in zip(reversed(skip), self.decoder):
            x = decoder_step(x)
            x = torch.cat((x, skip_connection), dim=1)

        # Last layer:
        x = self.last(x)
        x = self.final_activation(x)

        return x



class Discriminator(nn.Module):
    def __init__(self, leaky_relu_slope=0.2, in_channels=4, filters=(64, 128, 256, 512),
                 use_instance_norm=False, disable_norm=False, network_type='patch_gan'):
        """
        Discriminator network
        :param leaky_relu_slope: slope of leaky ReLU
        :param in_channels: input channels, 4 for RGB + grayscale
        :param filters: number of filters in each layer
        :param use_instance_norm: use instance normalization if batch size is 1
        :param disable_norm: disable normalization (for WGAN-GP)
        :param network_type: type of discriminator network, 'patch_gan' or 'DCGAN'
        """
        super(Discriminator, self).__init__()

        self.leaky_relu_slope = leaky_relu_slope
        self.in_channels = in_channels
        self.filters = filters
        assert network_type in ['patch_gan', 'DCGAN'], 'network_type must be either patch_gan or DCGAN'
        self.network_type = network_type

        self.model = nn.ModuleList()

        for i, n_filters in enumerate(filters):
            if i == 0:
                self.model.append(EncoderBlock(in_channels, n_filters, kernel_size=4,
                                               batch_norm=False,
                                               leaky_relu_slope=leaky_relu_slope,
                                               use_instance_norm=use_instance_norm))
            elif 1 <= i <= 2:
                self.model.append(EncoderBlock(filters[i - 1], n_filters, kernel_size=4,
                                               batch_norm=(not disable_norm),
                                               leaky_relu_slope=leaky_relu_slope,
                                               use_instance_norm=use_instance_norm))
            else:
                self.model.append(EncoderBlock(filters[i - 1], n_filters, kernel_size=4,
                                               batch_norm=(not disable_norm),
                                               stride=1, leaky_relu_slope=leaky_relu_slope,
                                               use_instance_norm=use_instance_norm))

        if self.network_type == 'patch_gan':
            last = nn.Conv2d(filters[-1], 1, kernel_size=4, stride=1, padding=1)
            # initialize weights for last layer using xavier uniform
            nn.init.xavier_uniform_(last.weight)
            nn.init.constant_(last.bias, 0)
            self.model.append(last)
        elif self.network_type == 'DCGAN':
            self.model.append(nn.AdaptiveAvgPool2d(1))
            self.model.append(nn.Flatten())
            self.model.append(nn.LazyLinear(1))

    def forward(self, x):
        for layer in self.model:
            x = layer(x)
        return x

# define generator
generator = Generator(use_instance_norm=False, in_channels=1, out_channels=2)
generator = torch.compile(generator)
generator.cuda()

# test forward pass
x = torch.randn(4, 1, 256, 256).cuda()
print(generator(x).shape)

# define discriminator
discriminator = Discriminator(use_instance_norm=False, in_channels=3, disable_norm=False, network_type='patch_gan')
discriminator = torch.compile(discriminator)
discriminator.cuda()

# test forward pass
x = torch.randn(4, 3, 256, 256).cuda()
print(discriminator(x).shape)

gen_opt = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.9))
disc_opt = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.9))

torch.autograd.set_detect_anomaly(True)
torch._dynamo.config.verbose=True

gray = torch.randn(4, 1, 256, 256).cuda() # L channel
color = torch.randn(4, 2, 256, 256).cuda() # ab channels

# train discriminator
disc_opt.zero_grad()

# train discriminator with fakes

fake_color = generator(gray)
fake_preds = discriminator(torch.cat((gray, fake_color.detach()), dim=1))
fake_targets = torch.zeros_like(fake_preds)
fake_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_preds, fake_targets)
fake_loss.backward()

# train discriminator with reals
real_preds = discriminator(torch.cat((gray, color), dim=1))
real_targets = torch.ones_like(real_preds)
real_loss = torch.nn.functional.binary_cross_entropy_with_logits(real_preds, real_targets)
real_loss.backward()

disc_opt.step()

# train generator
gen_opt.zero_grad()

fake_color = generator(gray)
fake_preds = discriminator(torch.cat((gray, fake_color), dim=1)) # error occurs here
fake_targets = torch.ones_like(fake_preds)
gen_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_preds, fake_targets)
gen_loss.backward()

gen_opt.step()

and the output when running:

torch.Size([4, 2, 256, 256])
torch.Size([4, 1, 30, 30])
/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/autograd/__init__.py:303: UserWarning: Error detected in ConvolutionBackward0. Traceback of forward call that caused the error:
  File "/mnt/c/Users/loren/Dropbox/Università/Vision and Cognitive Services/project/error.py", line 235, in forward                                                                                       
    x = layer(x)                                                                                                                                                                                          
 (Triggered internally at /opt/conda/conda-bld/pytorch_1682343967769/work/torch/csrc/autograd/python_anomaly_mode.cpp:114.)                                                                               
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass                                                                                                   
Traceback (most recent call last):                                                                                                          
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())                                                                               
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper     
    compiled_gm = compiler_fn(gm, example_inputs)                                                                                           
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__                     
    return compile_fx(model_, inputs_, config_patches=self.config)                                                                          
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx        
    return aot_autograd(
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2822, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2515, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1715, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2104, in aot_dispatch_autograd
    fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 714, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 443, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 652, in flatten_fn
    tree_out = root_fn(*tree_args)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 459, in wrapped
    out = f(*tensors)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1158, in traced_joint
    return functionalized_f_helper(primals, tangents)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1110, in functionalized_f_helper
    f_outs = flat_fn_no_input_mutations(fn, f_primals, f_tangents, meta, keep_input_mutations)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1078, in flat_fn_no_input_mutations
    outs = flat_fn_with_synthetic_bases_expanded(fn, primals, primals_after_cloning, maybe_tangents, meta, keep_input_mutations)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1050, in flat_fn_with_synthetic_bases_expanded
    outs = forward_or_joint(fn, primals_before_cloning, primals, maybe_tangents, meta, keep_input_mutations)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1019, in forward_or_joint
    backward_out = torch.autograd.grad(
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 269, in grad
    return handle_torch_function(
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/overrides.py", line 1534, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_inductor/overrides.py", line 38, in __torch_function__
    return func(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 303, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 487, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 512, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 345, in proxy_call
    out = func(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_ops.py", line 287, in __call__
    return self._op(*args, **kwargs or {})
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1162, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 410, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/c/Users/loren/Dropbox/Università/Vision and Cognitive Services/project/error.py", line 288, in <module>
    fake_preds = discriminator(torch.cat((gray, fake_color), dim=1)) # error occurs here
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/lorenzo/miniconda3/envs/torch_wsl/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised DataDependentOutputException: aten._local_scalar_dense.default


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True