Model doesn't work with dynamic input shapes after exporting to ONNX

Hi all, I created a model to classify videos of varying lengths. The procedure is as follows (code provided below):

  1. Embed the video frames into vectors separately with a pretrained ResNet34
  2. Reconstruct a sequence from these frame embeddings
  3. Produce a single vector from the sequence with a transformer
  4. Pass that vector through fully connected layers as the classifier

The original input shape before preprocessing is (batch_size, num_frames, height, width, num_channels) = (batch_size, num_frames, 360, 640, 3), the input to the model after preprocessing is (batch_size, num_frames, 3, 224, 224), and the output has shape (batch_size, 4), representing the logits of 4 classes for the batch_size videos.

I tested the PyTorch model with various values of num_frames and it all worked, but after I exported it to ONNX, the ONNX model does not work with any other value of num_frames. The batch_size is always 2 during inference.

Here is the code for the model:
P.S. There are 2 transformers in the model because there are 2 action types for the videos (the action type is known at inference time), and for each action type there are 2 classes to classify. (Imagine there are videos of tennis and of table tennis, and for each tennis/table-tennis video the model should classify whether it is a successful hit.)

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchvision import models


class CustomTransformerClassifier(nn.Module):
    def __init__(self, embed_size, num_heads, num_layers, num_classes):
        super(CustomTransformerClassifier, self).__init__()
        self.transformer_layer = TransformerEncoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_encoder = TransformerEncoder(self.transformer_layer, num_layers=num_layers)
        self.classifier = nn.Linear(embed_size, num_classes)

    def forward(self, inputs_embeds, attention_mask):
        # Permute (batch_size, sequence_length, embedding_size) to (sequence_length, batch_size, embedding_size)
        inputs_embeds = inputs_embeds.permute(1, 0, 2)
        
        # Applying transformer layers
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        transformer_out = self.transformer_encoder(inputs_embeds, src_key_padding_mask=attention_mask)
        
        # Pooling (e.g., taking the first token or mean pooling)
        pooled_output = transformer_out.mean(dim=0)

        # Classification
        logits = self.classifier(pooled_output)
        return logits
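
# Shape contract of this head (a sketch with arbitrary sizes): inputs_embeds of shape
# (batch_size, num_frames, embed_size=512) plus an optional padding mask go in, and a
# (batch_size, num_classes) logit tensor comes out, e.g.
#   CustomTransformerClassifier(512, 8, 3, 2)(torch.randn(2, 10, 512), attention_mask=None)
#   # -> tensor of shape (2, 2)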

class VideoClassifier(nn.Module):
    def __init__(self):
        super(VideoClassifier, self).__init__()
        # Frame embedding module using a pre-trained CNN
        self.frame_embedding = models.resnet34(pretrained=True)
        # Modify the frame embedding module to output embeddings instead of class predictions
        self.frame_embedding = nn.Sequential(*list(self.frame_embedding.children())[:-1])
        
        # Transformer model for processing sequences of embeddings
        self.action_type_a_transformer = CustomTransformerClassifier(embed_size=512, num_heads=8, num_layers=3, num_classes=2)
        self.action_type_b_transformer = CustomTransformerClassifier(embed_size=512, num_heads=8, num_layers=3, num_classes=2)

    def forward(self, x, attention_mask=None):
        # x is a batch of videos, each represented as a tensor of size (num_frames, C, H, W)
        batch_size, num_frames, C, H, W = x.size()
        x = x.view(batch_size * num_frames, C, H, W)  # Reshape for processing by frame_embedding module

        # Generate frame embeddings
        # Embed the frames in sub-batches so a long sequence doesn't create an overly large batch
        frame_embeddings = []
        subbatch_size = 8
        for i in range(0, batch_size * num_frames, subbatch_size):
            frame_embeddings.append(self.frame_embedding(x[i:(i+subbatch_size)]))
        frame_embeddings = torch.cat(frame_embeddings)
        frame_embeddings = frame_embeddings.view(batch_size, num_frames, -1)  # Reshape back to sequence format

        # Process sequence of embeddings through the Transformer
        action_type_a_output = self.action_type_a_transformer(inputs_embeds=frame_embeddings, attention_mask=attention_mask)
        action_type_b_output = self.action_type_b_transformer(inputs_embeds=frame_embeddings, attention_mask=attention_mask)
        output = torch.cat([action_type_a_output, action_type_b_output], dim=1)

        return output
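
To give an idea of the check I ran: a loop along these lines passes in eager PyTorch for every frame count I tried (the frame counts below are just arbitrary examples):

model = VideoClassifier()  # with trained weights loaded
model.eval()

with torch.no_grad():
    for num_frames in (1, 5, 16):
        dummy = torch.randn(2, num_frames, 3, 224, 224)
        print(num_frames, model(dummy).shape)  # always torch.Size([2, 4])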

The code for exporting the model:

model = VideoClassifier()  # the trained weights are loaded into this instance (loading code omitted)
class ModelWithPreprocessing(nn.Module):
    def __init__(self):
        super(ModelWithPreprocessing, self).__init__()
        self.preprocessing = valid_transform  # validation preprocessing (resizes frames to the 3x224x224 the ResNet expects), defined elsewhere
        self.model = model

    def single_video_preprocessing(self, x):
        x_processed = x.permute(0, 3, 1, 2) # T, H, W, C -> T, C, H, W
        x_processed = self.preprocessing(x_processed)
        
        return x_processed

    def forward(self, x):
        """
        Shape of x: (bs=2, T, H, W, C)
        """
        left_video, right_video = x
        left_video = self.single_video_preprocessing(left_video)
        right_video = self.single_video_preprocessing(right_video)
        
        x_processed = torch.stack([left_video, right_video])
        
        out = self.model(x_processed)
        out[:, :2] = nn.Softmax(dim=-1)(out[:, :2])
        out[:, 2:] = nn.Softmax(dim=-1)(out[:, 2:])
        
        return out

model_with_preprocessing = ModelWithPreprocessing()
model_with_preprocessing.eval()

onnx_output_path = "my_model.onnx"

# Specify the dummy input shape
input_shape = (2, 1, 360, 640, 3)  # bs, T (dynamic), H, W, C

# Create a dummy input tensor
dummy_input = torch.randn(*input_shape, dtype=torch.float32)

# Input and output names
input_names = ["input"]
output_names = ["output"]

# Define the dynamic axes for the input tensor
dynamic_axes = {
    "input": {1: "nb_frames"},
}

# Convert the model to ONNX format
torch.onnx.export(
    model_with_preprocessing,
    dummy_input,
    onnx_output_path,
    opset_version=14,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)
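
I run the exported model with onnxruntime along these lines (a minimal sketch; the input/output names match the export above):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("my_model.onnx")
video_batch = np.random.randn(2, 16, 360, 640, 3).astype(np.float32)  # one of the shapes that fails, see below
outputs = session.run(["output"], {"input": video_batch})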

Finally, here are the error messages I get when I pass an array of any shape other than (2, 1, 360, 640, 3) to the ONNX model:

Array of shape (2, 16, 360, 640, 3):

*** onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running MatMul node. Name:'/model/we_transformer/transformer_encoder/layers.0/self_attn/MatMul' Status Message: matmul_helper.h:61 Compute MatMul dimension mismatch

Array of shape (2, 5, 360, 640, 3):

*** onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'/model/Reshape_1' Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:36 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape &, onnxruntime::TensorShapeVector &, bool) size != 0 && (input_shape.Size() % size) == 0 was false. The input tensor cannot be reshaped to the requested shape. Input shape:{8,512,1,1}, requested shape:{2,5,-1}