With both PyTorch 1.13 and the nightly build I am hitting the same issue: exporting my model to ONNX and loading it back both work fine, but when I run inference through ONNX Runtime the output is (1) seemingly random when no trained weights are loaded, and (2) identical no matter what input I feed it. On top of that, the dynamic axes do not seem to take effect, even though they are specified in the export call.
Export
# Exporting my code like this
device = 'cuda'
tgt_in = torch.ones([1, 1], dtype=torch.long).to(device)
tgt_mask = torch.triu(torch.ones(1, 1) * float("-inf"), diagonal=1).to(device)
torch_model = Model_1(vocab_size=217, d_model=256, nhead=8, dim_FF=1024, dropout=0.3, num_layers=3).to(device)
#torch_model.load_state_dict(torch.load('runs/Exp8E8End_Acc=0.6702922582626343.pt', weights_only=True))
torch_model.eval()
torch_input = (torch.randn(1, 3, 384, 512).to(device),
tgt_in,
tgt_mask)
torch.onnx.export(
    torch_model,
    torch_input,
    "model_1.onnx",
    export_params=True,
    input_names=["image", "tgt", "tgt_mask"],
    output_names=["output1"],
    dynamic_axes={
        "tgt": {1: "tgt_axis_1"},
        "tgt_mask": {0: "tgt_mask_0", 1: "tgt_mask_1"},
        "output1": {1: "output1"},
    },
)
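To separate an export problem from a preprocessing problem, a parity check against the eager model on the exact tensors used for export should help. This is only a minimal sketch reusing the objects above:
# Parity check right after export
import numpy as np
import onnxruntime as ort

with torch.no_grad():
    torch_out = torch_model(*torch_input).cpu().numpy()

sess = ort.InferenceSession("model_1.onnx")
ort_out = sess.run(
    ["output1"],
    {
        "image": torch_input[0].cpu().numpy(),
        "tgt": tgt_in.cpu().numpy(),
        "tgt_mask": tgt_mask.cpu().numpy(),
    },
)[0]

# A mismatch here points at the exported graph rather than the image preprocessing
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5)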
Inference check
# Checking inference like this
import onnxruntime as ort
import numpy as np
from PIL import Image
onnx_model_path = "model_1.onnx"
session = ort.InferenceSession(onnx_model_path)
def load_and_preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = image.resize((512, 384))                   # PIL size is (W, H)
    image_array = np.array(image).transpose(2, 0, 1)   # HWC -> CHW
    image_array = np.expand_dims(image_array, axis=0)  # add batch dimension
    return image_array
test_image = load_and_preprocess_image('test3.png').astype(np.float32)
test_tgt = np.ones((1, 2), dtype=np.int64)
test_mask = np.triu(np.ones((2, 2), dtype=np.float32) * float('-inf'), k=1)
inputs = {"image": test_image,
"tgt": test_tgt,
"tgt_mask": test_mask
}
outputs = session.run(None, inputs)
outputs = outputs[0][0][0]  # logits for the first target position
max_val = max(outputs)
print(max_val)
print(np.where(outputs == max_val))  # index of the predicted token
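To see whether the dynamic axes were actually recorded in the file, the declared input shapes can be printed with the onnx package; a small sketch:
# Inspect the exported graph's input signature
import onnx

m = onnx.load("model_1.onnx")
onnx.checker.check_model(m)
for inp in m.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)  # dynamic dims should show up as names like "tgt_axis_1"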
Model Architecture
class Model_1(nn.Module): # from https://actamachina.com/handwritten-mathematical-expression-recognition, CNN encoder and then transformer decoder
def __init__(self, vocab_size, d_model, nhead, dim_FF, dropout, num_layers):
super(Model_1, self).__init__()
densenet = densenet121(weights=DenseNet121_Weights.DEFAULT)
self.encoder = nn.Sequential(
nn.Sequential(*list(densenet.children())[:-1]), # remove the final layer, output (B, 1024, 12, 16)
nn.Conv2d(1024, d_model, kernel_size=1), # 1x1 convolution, output of (B, d_model, W, H) ex. (1, 256, 12, 16)
Permute(0, 3, 2, 1),
PosEncode2D(d_model=d_model, dropout_percent=dropout, max_len=150, PE_temp=10000), # output (1, 16, 12, 256)
nn.Flatten(1, 2)
) # removed .to(device) here
self.tgt_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=0) # simple lookup table, input indices
self.word_PE = PosEncode1D(d_model, dropout, max_len=150, PE_temp=10000)
self.transformer_decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model, nhead, dim_FF, dropout, batch_first=True), # batch_first -> (batch, sequence, feature)
num_layers,
) # input target and memory (last sequence of the encoder), then tgt_mask, memory_mask
self.fc_out = nn.Linear(d_model, vocab_size) # y = xA^T + b, distribution over all tokens in vocabulary
self.d_model = d_model
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
init.xavier_uniform_(m.weight) # Glorot initialization for linear layers
if m.bias is not None:
init.zeros_(m.bias) # Bias initialized to zeros
elif isinstance(m, nn.Conv2d):
init.xavier_uniform_(m.weight) # Glorot initialization for conv layers
if m.bias is not None:
init.zeros_(m.bias)
elif isinstance(m, nn.Embedding):
init.xavier_uniform_(m.weight) # Glorot initialization for embedding layers
def decoder(self, features, tgt, tgt_mask):
padding_mask = tgt.eq(0) # checks where elements of tgt are equal to zero
tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model) # tgt indices become embedding vectors and are scaled by sqrt of model size for stability
tgt = self.word_PE(tgt) # adds positional encoding, size (B, seq_len, d_model)
tgt = self.transformer_decoder(tgt=tgt, memory=features, tgt_mask=tgt_mask, tgt_key_padding_mask=padding_mask, tgt_is_causal=True) # removed .to(torch.float32) for tgt_mask and padding_mask
output = self.fc_out(tgt) # size (B, seq_len, vocab_size)
return output
def forward(self, src, tgt, tgt_mask):
features = self.encoder(src)
output = self.decoder(features, tgt, tgt_mask)
return output
# UTIL FUNCTIONS
class Permute(nn.Module):
def __init__(self, *dims: int): # asterisk accepts an arbitrary number of arguments
super().__init__()
self.dims = dims
def forward(self, x):
return x.permute(*self.dims) # reorders the tuple
class PosEncode1D(nn.Module):
def __init__(self, d_model, dropout_percent, max_len, PE_temp):
super().__init__()
position = torch.arange(max_len).unsqueeze(1) # creates a vector (max_len x 1), 1 is needed for matmul operations
dim_t = torch.arange(0, d_model, 2) # 2i term in the denominator exponent
scaling = PE_temp **(dim_t/d_model) # entire denominator
pe = torch.zeros(max_len, d_model) # (max_len, d_model) table to fill
pe[:, 0::2] = torch.sin(position / scaling) # every second term starting from 0 (even)
pe[:, 1::2] = torch.cos(position / scaling) # every second term starting from 1 (odd)
self.dropout = nn.Dropout(dropout_percent)
self.register_buffer("pe", pe) # stores pe tensor to be used but not updated
def forward(self, x):
batch, sequence_length, d_model = x.shape
return self.dropout(x + self.pe[None, :sequence_length, :]) # None to broadcast across batch, adds element-wise [x + pe, . . .]
class PosEncode2D(nn.Module):
def __init__(self, d_model, dropout_percent, max_len, PE_temp):
super().__init__()
# 1D encoding
position = torch.arange(max_len).unsqueeze(1)
dim_t = torch.arange(0, d_model, 2)
scaling = PE_temp **(dim_t/d_model)
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position / scaling)
pe[:, 1::2] = torch.cos(position / scaling)
pe_2D = torch.zeros(max_len, max_len, d_model) # combine the 1D encodings additively across rows and columns
for i in range(d_model):
pe_2D[:, :, i] = pe[:, i].unsqueeze(1) + pe[:, i].unsqueeze(0) # first unsqueeze changed from -1
self.dropout = nn.Dropout(dropout_percent)
self.register_buffer("pe", pe_2D)
def forward(self, x):
batch, height, width, d_model = x.shape
return self.dropout(x + self.pe[None, :height, :width, :])
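For completeness, here is the eager-mode sanity check I would run (a sketch reusing the names from the export snippet) to confirm that the identical-output behaviour only appears after export:
# Two different random images should give different logits in plain PyTorch
with torch.no_grad():
    out_a = torch_model(torch.randn(1, 3, 384, 512, device=device), tgt_in, tgt_mask)
    out_b = torch_model(torch.randn(1, 3, 384, 512, device=device), tgt_in, tgt_mask)
print(torch.allclose(out_a, out_b))  # expected False if the output really depends on the image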
I have also tried using torch.onnx.dynamo_export, but that fails with the same 'torch._dynamo.exc.Unsupported: skip function MyModel.forward' error even when I copy the simplest example from the docs.
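The dynamo attempt is essentially the documented pattern, roughly this (a sketch, not my exact script):
# dynamo-based export attempt, same model and inputs as above
onnx_program = torch.onnx.dynamo_export(torch_model, *torch_input)
onnx_program.save("model_1_dynamo.onnx")
The console output from the nightly environment: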
C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\onnxscript\converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\onnxscript\converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\_exporter_legacy.py:108: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
Traceback (most recent call last):
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\_exporter_legacy.py", line 798, in dynamo_export
).export()
^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\_exporter_legacy.py", line 557, in export
graph_module = self.options.fx_tracer.generate_fx(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\fx\dynamo_graph_extractor.py", line 198, in generate_fx
graph_module, graph_guard = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\eval_frame.py", line 1539, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\eval_frame.py", line 556, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 1428, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 550, in __call__
return _compile(
^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 979, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 709, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 744, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1348, in transform_code_object
transformations(instructions, code_options)
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 234, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 663, in transform
tracer.run()
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2914, in run
super().run()
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1120, in run
while self.step():
^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1032, in step
self.dispatch_table[inst.opcode](self, inst)
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 640, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1816, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 967, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\variables\functions.py", line 751, in call_function
unimplemented(msg)
File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\exc.py", line 313, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: 'skip function MyModel.forward in file
If anyone has any insight into what is going on, or knows whether I'm doing something wrong, I would very much appreciate the help. Thank you!