[dynamo] [onnx] Models exported with torch.onnx.dynamo_export show worse performance during inference

Models exported with torch.onnx.dynamo_export run significantly slower under ONNX Runtime inference than the same models exported with torch.onnx.export: in the repro below, the dynamo-exported DNN is about 1.3x slower and the CNN about 2.2x slower on CPU.

Code to reproduce:

import torch
import torch.nn as nn
import torch.onnx
import onnxruntime
import numpy as np
import time

# Define the DNN model
class DNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.fc(x)

# Define the CNN model
class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

batch_size = 128

# Dummy input to match the model's input size
dnn_input = torch.randn(batch_size, 512)
cnn_input = torch.randn(batch_size, 1, 28, 28)  # Image size is 28x28

# Instantiate models
dnn_model = DNNModel()
cnn_model = CNNModel()

# Export models using torch.onnx.export
torch.onnx.export(dnn_model, dnn_input, 'dnn_model.onnx')
torch.onnx.export(cnn_model, cnn_input, 'cnn_model.onnx')

# Export models using torch.onnx.dynamo_export
torch.onnx.dynamo_export(dnn_model, dnn_input).save('dnn_model.dynamo.onnx')
torch.onnx.dynamo_export(cnn_model, cnn_input).save('cnn_model.dynamo.onnx')
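
# Optional sanity check (not part of the original report): confirm the dynamo
# export matches eager-mode outputs before timing, so the two files compute
# the same thing; the tolerances below are a loose guess.
with torch.no_grad():
    ref = dnn_model(dnn_input).numpy()
check_sess = onnxruntime.InferenceSession('dnn_model.dynamo.onnx',
                                          providers=['CPUExecutionProvider'])
check_out = check_sess.run(None, {check_sess.get_inputs()[0].name: dnn_input.numpy()})[0]
np.testing.assert_allclose(ref, check_out, rtol=1e-4, atol=1e-5)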

# Measure average ONNX Runtime inference time in ms per batch;
# input_data is a NumPy array (what ORT consumes directly)
def test_onnx_model(model_path, input_data):
    session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    timings = []

    for _ in range(1000):
        start_time = time.time()
        session.run(None, {input_name: input_data})
        end_time = time.time()
        timings.append(end_time - start_time)

    return round(np.mean(timings) * 1000, 4)

# Benchmark each model with fresh NumPy inputs
dnn_input = np.random.randn(batch_size, 512).astype(np.float32)
cnn_input = np.random.randn(batch_size, 1, 28, 28).astype(np.float32)

print("#### ~~~ torch.onnx.export")
print(f"DNN Average Inference Time: {test_onnx_model('dnn_model.onnx', dnn_input)} ms/batch")
print(f"CNN Average Inference Time: {test_onnx_model('cnn_model.onnx', cnn_input)} ms/batch")

print("#### ~~~ torch.onnx.dynamo_export")
print(f"DNN Average Inference Time: {test_onnx_model('dnn_model.dynamo.onnx', dnn_input)} ms/batch")
print(f"CNN Average Inference Time: {test_onnx_model('cnn_model.dynamo.onnx', cnn_input)} ms/batch")

Output:

#### ~~~ torch.onnx.export
DNN Average Inference Time: 0.156 ms/batch
CNN Average Inference Time: 2.2168 ms/batch
#### ~~~ torch.onnx.dynamo_export
DNN Average Inference Time: 0.2092 ms/batch
CNN Average Inference Time: 4.8949 ms/batch
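
One possible explanation is that the dynamo exporter emits a different, often more decomposed set of operators, which ONNX Runtime may optimize differently. Comparing the operator makeup of the two exports can confirm this; a minimal sketch, assuming the onnx Python package is installed and the files above have been written:

import onnx
from collections import Counter

# Print the count of each ONNX op type in both CNN exports
for path in ['cnn_model.onnx', 'cnn_model.dynamo.onnx']:
    graph = onnx.load(path).graph
    print(path, dict(Counter(node.op_type for node in graph.node)))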

PyTorch version: 2.2
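
One more check worth running (my suggestion, not something from the report above): force ONNX Runtime's highest graph-optimization level and dump the graph it actually executes, to see whether the dynamo-exported model simply isn't being fused. The output filename below is an arbitrary placeholder:

import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# Write the post-optimization graph to disk for inspection
sess_options.optimized_model_filepath = 'cnn_model.dynamo.opt.onnx'
session = onnxruntime.InferenceSession('cnn_model.dynamo.onnx', sess_options,
                                       providers=['CPUExecutionProvider'])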