ONNX model produces same output regardless of input

With both PyTorch 1.13 and the nightly build, I am facing the same issue: saving my model to ONNX and loading it back works fine, but when running inference via ONNX Runtime, the output is (1) seemingly random when no weights are loaded, and (2) exactly the same regardless of the input I give it. Moreover, the dynamic axes seem not to work, even though they are specified in the export settings.
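To make the symptom concrete: with the model exported below, two different random images produce identical logits. A minimal sketch of that check:

# Sketch: probe the exported model with two different random images
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model_1.onnx")
tgt = np.ones((1, 1), dtype=np.int64)
mask = np.zeros((1, 1), dtype=np.float32)  # causal mask for a length-1 target is all zeros

out_a = session.run(None, {"image": np.random.randn(1, 3, 384, 512).astype(np.float32),
                           "tgt": tgt, "tgt_mask": mask})[0]
out_b = session.run(None, {"image": np.random.randn(1, 3, 384, 512).astype(np.float32),
                           "tgt": tgt, "tgt_mask": mask})[0]
print(np.abs(out_a - out_b).max())  # near zero for me, i.e. the image input is effectively ignored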

Export

# Exporting my model like this
import torch
from Models import Model_1  # model definition shown below

device = 'cuda'
tgt_in = torch.ones([1, 1], dtype=torch.long).to(device)  # dummy target sequence of length 1
tgt_mask = torch.triu(torch.ones(1, 1) * float("-inf"), diagonal=1).to(device)  # causal mask (all zeros for length 1)

torch_model = Model_1(vocab_size=217, d_model=256, nhead=8, dim_FF=1024, dropout=0.3, num_layers=3).to(device)
#torch_model.load_state_dict(torch.load('runs/Exp8E8End_Acc=0.6702922582626343.pt', weights_only=True))
torch_model.eval()

torch_input = (torch.randn(1, 3, 384, 512).to(device), 
               tgt_in, 
               tgt_mask)

torch.onnx.export(
    torch_model, 
    torch_input,
    "model_1.onnx",
    export_params = True,
    input_names = ["image", "tgt", "tgt_mask"],
    output_names = ["output1"],
    dynamic_axes = {
        "tgt": {1: "tgt_axis_1"},
        "tgt_mask": {0: "tgt_mask_0", 1: "tgt_mask_1"},
        "output1": {1: "output1"}
    }
)
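To sanity-check the export itself, I load the file back and inspect the graph inputs (a sketch using the onnx package; if the dynamic axes were applied, they should show up as named dim_param entries rather than fixed sizes):

import onnx

onnx_model = onnx.load("model_1.onnx")
onnx.checker.check_model(onnx_model)  # raises if the graph is structurally invalid

for inp in onnx_model.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)  # e.g. 'tgt' [1, 'tgt_axis_1'] if the dynamic axis took effect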

Inference check

# Checking inference like this
import onnxruntime as ort
import numpy as np
from PIL import Image

onnx_model_path = "model_1.onnx"
session = ort.InferenceSession(onnx_model_path)

def load_and_preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = image.resize((512, 384))
    image_array = np.array(image).transpose(2, 0, 1)
    image_array = np.expand_dims(image_array, axis=0)
    return image_array

test_image = load_and_preprocess_image('test3.png').astype(np.float32)
test_tgt = np.ones((1, 2), dtype=np.int64)
test_mask = np.triu(np.ones((2, 2), dtype=np.float32) * float('-inf'), k=1)

inputs = {"image": test_image, 
          "tgt": test_tgt, 
          "tgt_mask": test_mask
          }

ort_outputs = session.run(None, inputs)
logits = ort_outputs[0][0][0]  # logits for the first target position, shape (vocab_size,)
max_val = logits.max()
print(max_val)
print(np.argmax(logits))  # index of the highest logit, i.e. the predicted token
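To separate an export problem from a preprocessing problem, the same arrays can also be fed to the eager PyTorch model (a sketch; torch_model and device are assumed to still be in scope from the export script above):

import torch

with torch.no_grad():
    torch_out = torch_model(torch.from_numpy(test_image).to(device),
                            torch.from_numpy(test_tgt).to(device),
                            torch.from_numpy(test_mask).to(device))

# If eager PyTorch and ONNX Runtime disagree here, the export is at fault;
# if they agree, the problem lies elsewhere (e.g. in the preprocessing).
print(np.abs(torch_out.cpu().numpy() - ort_outputs[0]).max())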

Model Architecture

class Model_1(nn.Module): # from https://actamachina.com/handwritten-mathematical-expression-recognition, CNN encoder and then transformer decoder
    def __init__(self, vocab_size, d_model, nhead, dim_FF, dropout, num_layers):
        super(Model_1, self).__init__()
        densenet = densenet121(weights=DenseNet121_Weights.DEFAULT) 
        
        self.encoder = nn.Sequential(
            nn.Sequential(*list(densenet.children())[:-1]), # remove the final layer, output (B, 1024, 12, 16)
            nn.Conv2d(1024, d_model, kernel_size=1), # 1x1 convolution, output of (B, d_model, W, H) ex. (1, 256, 12, 16)
            Permute(0, 3, 2, 1), 
            PosEncode2D(d_model=d_model, dropout_percent=dropout, max_len=150, PE_temp=10000), # output (1, 16, 12, 256)
            nn.Flatten(1, 2)
        ) # removed .to(device) here
    
        self.tgt_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=0) # simple lookup table, input indices 
        self.word_PE = PosEncode1D(d_model, dropout, max_len=150, PE_temp=10000)
        self.transformer_decoder = nn.TransformerDecoder( 
            nn.TransformerDecoderLayer(d_model, nhead, dim_FF, dropout, batch_first=True), # batch_first -> (batch, sequence, feature)
            num_layers,
        ) # input target and memory (last sequence of the encoder), then tgt_mask, memory_mask
        self.fc_out = nn.Linear(d_model, vocab_size) # y = xA^T + b, distribution over all tokens in vocabulary
        self.d_model = d_model

        self._initialize_weights()  

    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)  # Glorot initialization for linear layers
                if m.bias is not None:
                    init.zeros_(m.bias)  # Bias initialized to zeros
            elif isinstance(m, nn.Conv2d):
                init.xavier_uniform_(m.weight)  # Glorot initialization for conv layers
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                init.xavier_uniform_(m.weight)  # Glorot initialization for embedding layers

    def decoder(self, features, tgt, tgt_mask):
        padding_mask = tgt.eq(0) # checks where elements of tgt are equal to zero
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model) # tgt indices become embedding vectors and are scaled by sqrt of model size for stability
        tgt = self.word_PE(tgt) # adds positional encoding, size (B, seq_len, d_model)
        tgt = self.transformer_decoder(tgt=tgt, memory=features, tgt_mask=tgt_mask, tgt_key_padding_mask=padding_mask, tgt_is_causal=True) # removed .to(torch.float32) for tgt_mask and padding_mask
        output = self.fc_out(tgt) # size (B, seq_len, vocab_size)
        return output
    
    def forward(self, src, tgt, tgt_mask):
        features = self.encoder(src)
        output = self.decoder(features, tgt, tgt_mask)
        return output

# UTIL FUNCTIONS

class Permute(nn.Module):
    def __init__(self, *dims: int): # asterisk accepts an arbitrary number of arguments
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims) # reorders the tuple
    
class PosEncode1D(nn.Module):
    def __init__(self, d_model, dropout_percent, max_len, PE_temp):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1) # creates a vector (max_len x 1), 1 is needed for matmul operations
        dim_t = torch.arange(0, d_model, 2) # 2i term in the denominator exponent
        scaling = PE_temp **(dim_t/d_model) # entire denominator 

        pe = torch.zeros(max_len, d_model) # empty (max_len, d_model) table, filled in below
        pe[:, 0::2] = torch.sin(position / scaling) # every second term starting from 0 (even)
        pe[:, 1::2] = torch.cos(position / scaling) # every second term starting from 1 (odd)

        self.dropout = nn.Dropout(dropout_percent)
        self.register_buffer("pe", pe) # stores pe tensor to be used but not updated

    def forward(self, x):
        batch, sequence_length, d_model = x.shape         
        return self.dropout(x + self.pe[None, :sequence_length, :]) # None to broadcast across batch, adds element-wise [x + pe, . . .]
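For reference, the buffer holds the standard sinusoidal encoding with temperature $T$ = PE_temp, i.e. for position $p$ and channel pair $2i$, $2i+1$:

$$\mathrm{PE}(p,\,2i) = \sin\!\left(\frac{p}{T^{2i/d_{\mathrm{model}}}}\right), \qquad \mathrm{PE}(p,\,2i+1) = \cos\!\left(\frac{p}{T^{2i/d_{\mathrm{model}}}}\right)$$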


class PosEncode2D(nn.Module):
    def __init__(self, d_model, dropout_percent, max_len, PE_temp):
        super().__init__()
        # 1D encoding
        position = torch.arange(max_len).unsqueeze(1) 
        dim_t = torch.arange(0, d_model, 2)
        scaling = PE_temp **(dim_t/d_model)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position / scaling)
        pe[:, 1::2] = torch.cos(position / scaling)

        pe_2D = torch.zeros(max_len, max_len, d_model) # 2D encoding: element-wise sum of row and column 1D encodings
        for i in range(d_model):
            pe_2D[:, :, i] = pe[:, i].unsqueeze(1) + pe[:, i].unsqueeze(0)  # first unsqueeze changed from -1

        self.dropout = nn.Dropout(dropout_percent)
        self.register_buffer("pe", pe_2D) 

    def forward(self, x):
        batch, height, width, d_model = x.shape
        return self.dropout(x + self.pe[None, :height, :width, :])
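To rule out the architecture itself, a quick eager-mode smoke test (a sketch reusing tgt_in, tgt_mask, and torch_model from the export script) shows whether two different images already produce identical logits before ONNX is involved:

torch_model.eval()  # already in eval mode after the export script, but just in case
with torch.no_grad():
    out_a = torch_model(torch.randn(1, 3, 384, 512, device=device), tgt_in, tgt_mask)
    out_b = torch_model(torch.randn(1, 3, 384, 512, device=device), tgt_in, tgt_mask)
print((out_a - out_b).abs().max())  # ~0 here would mean the model itself ignores the image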

I have also tried torch.onnx.dynamo_export, but that fails with the same torch._dynamo.exc.Unsupported: 'skip function MyModel.forward' error even when I copy the simplest example from the docs.
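The call was roughly the following (a sketch; the output filename is a placeholder):

onnx_program = torch.onnx.dynamo_export(torch_model, *torch_input)
onnx_program.save("model_1_dynamo.onnx")

Console output: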

C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\onnxscript\converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\onnxscript\converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\_exporter_legacy.py:108: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
Traceback (most recent call last):
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\_exporter_legacy.py", line 798, in dynamo_export
    ).export()
      ^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\_exporter_legacy.py", line 557, in export
    graph_module = self.options.fx_tracer.generate_fx(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\onnx\_internal\fx\dynamo_graph_extractor.py", line 198, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
                                ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\eval_frame.py", line 1539, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\eval_frame.py", line 556, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 1428, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 550, in __call__
    return _compile(
           ^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 979, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 709, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 744, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1348, in transform_code_object
    transformations(instructions, code_options)
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 234, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\convert_frame.py", line 663, in transform
    tracer.run()
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2914, in run
    super().run()
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1120, in run
    while self.step():
          ^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1032, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 640, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1816, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 967, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\variables\functions.py", line 751, in call_function
    unimplemented(msg)
  File "C:\Users\edmun\anaconda3\envs\torchnightly\Lib\site-packages\torch\_dynamo\exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: 'skip function MyModel.forward in file

If anyone has any insight into what is going on, or knows whether I'm doing something wrong, I would very much appreciate the help. Thank you!