I have a simple model file as follows:
import torch
import torch.nn as nn
import torch.fx as fx
class CustomTracer(fx.Tracer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def is_leaf_module(self, module, name):
return isinstance(module, CustomModule)
class CustomModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
conv = nn.Conv2d(3, 64, 3, padding=1)
return conv(x)
def utility_function(x):
return torch.relu(x)
class UtilityClass:
def __init__(self):
pass
def method(self, x):
bn = nn.BatchNorm2d(x.size(1))
return bn(x)
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.custom_module = CustomModule()
self.utility_class = UtilityClass()
def forward(self, x):
x = self.custom_module(x)
x = utility_function(x)
x = self.utility_class.method(x)
return x
def my_model_function():
return MyModel()
I get my model like this:model_function = getattr(model_module, "my_model_function")
.
I use two methods to derive the results of the model:
output_original = model(input)
python jit_model = torch.jit.trace(model, input) output_jit = jit_model(input)
and get this output:
Mismatched elements: 3211246 / 3211264 (100.0%)
Greatest absolute difference: 60.796468168497086 at index (0, 22, 5, 223) (up to 1e-05 allowed)
Greatest relative difference: 19350106.104020353 at index (0, 16, 223, 168) (up to 1e-05 allowed)
Why this happened?