Get different results after JIT compilation optimization

I have a simple model file as follows:

import torch
import torch.nn as nn
import torch.fx as fx

class CustomTracer(fx.Tracer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def is_leaf_module(self, module, name):
        return isinstance(module, CustomModule)

class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        conv = nn.Conv2d(3, 64, 3, padding=1)
        return conv(x)

def utility_function(x):
    return torch.relu(x)

class UtilityClass:
    def __init__(self):
        pass

    def method(self, x):
        bn = nn.BatchNorm2d(x.size(1))
        return bn(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.custom_module = CustomModule()
        self.utility_class = UtilityClass()

    def forward(self, x):
        x = self.custom_module(x)
        x = utility_function(x)
        x = self.utility_class.method(x)
        return x

def my_model_function():
    return MyModel()

I get my model like this: model_function = getattr(model_module, "my_model_function").

I use two methods to derive the results of the model:

  • output_original = model(input)
  • jit_model = torch.jit.trace(model, input)
    output_jit = jit_model(input)

and get this output:

Mismatched elements: 3211246 / 3211264 (100.0%)
Greatest absolute difference: 60.796468168497086 at index (0, 22, 5, 223) (up to 1e-05 allowed)
Greatest relative difference: 19350106.104020353 at index (0, 16, 223, 168) (up to 1e-05 allowed)

Why did this happen?

Your model does not even produce the same outputs in Eager mode:

model = my_model_function()
x = torch.randn(1, 3, 24, 24)
out1 = model(x)
out2 = model(x)

torch.testing.assert_allclose(out1, out2)
# Mismatched elements: 36863 / 36864 (100.0%)
# Greatest absolute difference: 7.237934112548828 at index (0, 11, 11, 19) (up to 1e-05 allowed)
# Greatest relative difference: 49341.4609375 at index (0, 50, 23, 10) (up to 0.0001 allowed)

so I would focus on fixing this first before digging into the deprecated torch.jit.trace utility.

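The differences are most likely caused by creating the modules in the forward pass: every call builds a new nn.Conv2d with freshly initialized random weights (and a new nn.BatchNorm2d whose state starts from scratch), so no two calls use the same parameters. Create these modules once in the __init__ of the parent class and only call them in forward.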
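As a rough sketch of that change (not the only way to restructure it): the class name FixedModel is made up for illustration, BatchNorm2d(64) hard-codes the channel count that the original code read from x.size(1), and calling eval() before tracing is my own assumption to keep the batch norm behavior fixed during the comparison.

```python
import torch
import torch.nn as nn


class FixedModel(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        # Layers are created once here, so their parameters persist across calls.
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        # Assumption: 64 channels, matching the conv output above.
        self.bn = nn.BatchNorm2d(64)

    def forward(self, x):
        x = self.conv(x)
        x = torch.relu(x)
        return self.bn(x)


torch.manual_seed(0)
model = FixedModel().eval()  # assumption: compare in eval mode
x = torch.randn(1, 3, 24, 24)

with torch.no_grad():
    out1 = model(x)
    out2 = model(x)
    jit_model = torch.jit.trace(model, x)
    out_jit = jit_model(x)

# Two eager calls should now agree, and so should the traced model.
torch.testing.assert_close(out1, out2)
torch.testing.assert_close(out1, out_jit)
```

With the layers registered in __init__, they also show up in model.parameters() and model.state_dict(), which the versions created inside forward never did.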

Ah, I see!! Thanks for your help.

And, hi, sorry to bother you! I’m just curious — may I ask why you said torch.jit.trace is “deprecated” or not recommended? I can still find it in the official PyTorch 2.6 documentation, and it seems to be supported. Would you mind sharing any official notes, issues, or discussions that mention its deprecation or reduced support? I’d really appreciate it. Thank you!

You can find a few comments on GitHub, e.g. this one from June 2023, which explains that feature development is already “frozen”.