Models exported with torch.onnx.dynamo_export run noticeably slower under ONNX Runtime than the same models exported with torch.onnx.export: in the benchmark below, the Dynamo-exported DNN is about 1.3x slower and the CNN more than 2x slower.
code:
import torch
import torch.nn as nn
import torch.onnx
import onnxruntime
import numpy as np
import time

# Define the DNN model
class DNNModel(nn.Module):
    def __init__(self):
        super(DNNModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.fc(x)

# Define the CNN model
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

batch_size = 128

# Dummy inputs matching each model's input size
dnn_input = torch.randn(batch_size, 512)
cnn_input = torch.randn(batch_size, 1, 28, 28)  # image size is 28x28

# Instantiate models
dnn_model = DNNModel()
cnn_model = CNNModel()

# Export models using torch.onnx.export (TorchScript-based exporter)
torch.onnx.export(dnn_model, dnn_input, 'dnn_model.onnx')
torch.onnx.export(cnn_model, cnn_input, 'cnn_model.onnx')

# Export models using torch.onnx.dynamo_export (Dynamo-based exporter)
torch.onnx.dynamo_export(dnn_model, dnn_input).save('dnn_model.dynamo.onnx')
torch.onnx.dynamo_export(cnn_model, cnn_input).save('cnn_model.dynamo.onnx')

# Measure average inference latency of an ONNX model over 1000 runs
def test_onnx_model(model_path, input_data):
    # Execution providers belong in `providers`, not `provider_options`
    session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    timings = []
    for _ in range(1000):
        start_time = time.time()
        session.run(None, {input_name: input_data})
        end_time = time.time()
        timings.append(end_time - start_time)
    # Mean latency in milliseconds per batch
    return (np.mean(timings) * 1000).round(4)

# Test each model on fresh float32 inputs (fed to ORT as numpy arrays directly)
dnn_input = np.random.randn(batch_size, 512).astype(np.float32)
cnn_input = np.random.randn(batch_size, 1, 28, 28).astype(np.float32)

print("#### ~~~ torch.onnx.export")
print(f"DNN Average Inference Time: {test_onnx_model('dnn_model.onnx', dnn_input)} ms/batch")
print(f"CNN Average Inference Time: {test_onnx_model('cnn_model.onnx', cnn_input)} ms/batch")
print("#### ~~~ torch.onnx.dynamo_export")
print(f"DNN Average Inference Time: {test_onnx_model('dnn_model.dynamo.onnx', dnn_input)} ms/batch")
print(f"CNN Average Inference Time: {test_onnx_model('cnn_model.dynamo.onnx', cnn_input)} ms/batch")
output:
#### ~~~ torch.onnx.export
DNN Average Inference Time: 0.156 ms/batch
CNN Average Inference Time: 2.2168 ms/batch
#### ~~~ torch.onnx.dynamo_export
DNN Average Inference Time: 0.2092 ms/batch
CNN Average Inference Time: 4.8949 ms/batch
PyTorch: 2.2
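A first step in diagnosing the gap would be to check whether the two exporters are emitting comparable graphs; as far as I understand, the Dynamo-based exporter can produce a differently structured graph (and may wrap decomposed ops in ONNX local functions), which ONNX Runtime's optimizer may fuse less aggressively. A minimal diagnostic sketch, assuming the four .onnx files from the repro above are on disk, that prints each graph's op-type histogram and saves the graph ORT actually executes after optimization:

import collections
import onnx
import onnxruntime

def op_histogram(path):
    # Count op types in the top-level graph; ops inside local functions
    # (which dynamo_export may emit) are not included in this count
    model = onnx.load(path)
    return collections.Counter(node.op_type for node in model.graph.node)

for path in ['dnn_model.onnx', 'dnn_model.dynamo.onnx',
             'cnn_model.onnx', 'cnn_model.dynamo.onnx']:
    print(path, dict(op_histogram(path)))

# Save the post-optimization graph that ONNX Runtime actually runs, so the
# two exporters' CNN graphs can be diffed after ORT's fusions are applied
for path in ['cnn_model.onnx', 'cnn_model.dynamo.onnx']:
    so = onnxruntime.SessionOptions()
    so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.optimized_model_filepath = path.replace('.onnx', '.optimized.onnx')
    onnxruntime.InferenceSession(path, sess_options=so, providers=['CPUExecutionProvider'])

If the optimized Dynamo-exported graphs still contain more (or less fused) nodes than their torch.onnx.export counterparts, that graph difference, rather than the benchmark harness, would account for most of the latency gap.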