Problem with mixing torch.jit.script and torch.jit.trace

Hello, I am using torch.jit.script to export my models, but I ran into a problem with ResNet. This is my model definition:

import torch
import torch.nn as nn
import torchvision

class Model(nn.Module):
    def __init__(self, num_cats):
        super(Model, self).__init__()
        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_cats)
        # trace the backbone with an example input; the wrapper is scripted later
        self.model = torch.jit.trace(self.model, torch.rand(1, 3, 224, 224))

    def forward(self, x):
        x = self.model(x)
        return x

Training works fine and validation accuracy behaves as expected. Finally, I export my model like this:

scripted = torch.jit.script(model)
torch.jit.save(scripted, 'scripted_model.pth')

The thing is that when I load my model

model = torch.jit.load('scripted_model.pth')

and use it to perform inference on the entire validation set with the same batch size as the one used for training, I get the same (correct) predictions. But if I try to perform inference on a single image (or on the validation set with a small batch size), I get totally wrong predictions.
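For reference, the single-image inference looks roughly like this (the preprocessing transform and image path are just illustrative placeholders, not necessarily the ones from my training pipeline):

import torch
from PIL import Image
from torchvision import transforms

# illustrative preprocessing; the actual training transform may differ
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

model = torch.jit.load('scripted_model.pth')
model.eval()

img = Image.open('some_image.jpg').convert('RGB')  # placeholder path
batch = preprocess(img).unsqueeze(0)                # shape (1, 3, 224, 224)

with torch.no_grad():
    pred = model(batch).argmax(dim=1)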

I have tried the same example with a custom CNN and get good results. I also tried exporting the resnet model with torch.jit.trace and torch.jit.trace_module (instead of torch.jit.script), roughly as sketched below, and again get good results.
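The trace-based exports I tried look roughly like this (the output filenames here are just examples):

import torch

# export by tracing the whole trained model
traced = torch.jit.trace(model, torch.rand(1, 3, 224, 224))
torch.jit.save(traced, 'traced_model.pth')

# or by tracing its forward method explicitly
traced2 = torch.jit.trace_module(model, {'forward': torch.rand(1, 3, 224, 224)})
torch.jit.save(traced2, 'traced_model2.pth')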

I would appreciate any help with this issue, since I find torch.jit.script very useful for production.

Thanks for the minimal repro! This sounds like a bug; would you mind filing an issue on GitHub and adding some info about your environment so we can track it better?

At first glance I’m not seeing any difference between eager mode, the original scripted model, and the loaded scripted model (script here). Since self.model here is a traced graph, the compilation should also be very simple: all it would really compile is the return self.model(x) statement, so I’m not sure why the results would be different.
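Roughly the comparison I ran (using the Model class from above, an arbitrary number of classes, and a random input):

import torch

model = Model(num_cats=10)  # arbitrary number of classes
model.eval()

scripted = torch.jit.script(model)
torch.jit.save(scripted, 'scripted_model.pth')
loaded = torch.jit.load('scripted_model.pth')

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    eager_out = model(x)
    print(torch.allclose(eager_out, scripted(x)))
    print(torch.allclose(eager_out, loaded(x)))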

Hi, you are right, there is no difference between the models. The thing is that they are all wrong. Consider the following models:

class Model1(nn.Module):
    # wraps a traced resnet18; the wrapper is scripted at export time
    def __init__(self, num_cats):
        super(Model1, self).__init__()
        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_cats)
        self.model = torch.jit.trace(self.model, torch.rand(1, 3, 224, 224))

    def forward(self, x):
        x = self.model(x)
        return x

class Model2(nn.Module):
    # plain eager resnet18; the whole model is traced at export time
    def __init__(self, num_cats):
        super(Model2, self).__init__()
        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_cats)

    def forward(self, x):
        x = self.model(x)
        return x

I can train both models with the same dataset and hyperparameters and achieve similarly good results. I then export them as follows:

scripted1 = torch.jit.script(model1)
torch.jit.save(scripted1, 'scripted_model1.pth')

scripted2 = torch.jit.trace(model2, torch.rand(1, 3, 224, 224))
torch.jit.save(scripted2, 'scripted_model2.pth')

And load them

loaded1 = torch.jit.load('scripted_model1.pth')
loaded2 = torch.jit.load('scripted_model2.pth')

In all cases, the three versions of each model give the same results. The difference is that the first model only works correctly when performing inference with a large batch size (and fails on single-image inference), while the second model gives good predictions in all cases.
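For completeness, a quick way to see the batch-size dependence without needing labels is to check whether the prediction for a given image changes depending on the batch it is embedded in (the inputs here are random, just to illustrate the check; in my case I use real validation images):

import torch

loaded1 = torch.jit.load('scripted_model1.pth')
loaded1.eval()

x = torch.rand(64, 3, 224, 224)  # stand-in for a validation batch

with torch.no_grad():
    out_batch = loaded1(x)        # first image evaluated inside a large batch
    out_single = loaded1(x[:1])   # the same image evaluated alone
    # for a well-behaved model these should match
    print(torch.allclose(out_batch[:1], out_single, atol=1e-5))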

PS: I tried the same experiment as the one described here, but with a custom CNN in place of resnet18, and I observe the same problem. Hence it seems to be something related to the interaction between torch.jit.script and torch.jit.trace.
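A minimal version of that custom-CNN variant looks roughly like this (the architecture is just a toy stand-in for my actual network):

import torch
import torch.nn as nn

class TinyCNN(nn.Module):
    # toy stand-in for the custom CNN
    def __init__(self, num_cats):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(16, num_cats)

    def forward(self, x):
        return self.fc(self.features(x).flatten(1))

class Wrapped(nn.Module):
    # same pattern as Model1: trace the inner module, then script the wrapper
    def __init__(self, num_cats):
        super().__init__()
        self.model = torch.jit.trace(TinyCNN(num_cats), torch.rand(1, 3, 224, 224))

    def forward(self, x):
        return self.model(x)

scripted = torch.jit.script(Wrapped(num_cats=10))  # arbitrary number of classes
torch.jit.save(scripted, 'scripted_custom.pth')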