Export torch to ONNX with modules as functions

Hi, I’m trying to export a torch model to ONNX and generate ONNX with functions. I found that torch is deprecating export_modules_as_functions (TorchDynamo-based ONNX Exporter — PyTorch 2.8 documentation). Does that mean equivalent behavior is enabled by default when we pass dynamo=True? When I tried enabling dynamo, I did not see the functions; instead the modules were decomposed into aten ops.

Below is the code I’m using to export modules as functions, which works. But when I enable dynamo, everything is decomposed into aten ops. Is there a way to keep the modules as ONNX functions when dynamo is enabled, or are there other things to consider or look for?

import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.layer_norm = nn.LayerNorm(hidden_dim)
        
    def forward(self, x):
        residual = x
        x = self.activation(self.linear1(x))
        x = self.linear2(x)
        x = self.layer_norm(x + residual)
        return x

class MainModel(nn.Module):
    def __init__(self, input_dim=8, hidden_dim=16, output_dim=4, num_decoder_blocks=6):
        super().__init__()
        # Pre-processing layers (inline, not separate modules)
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.encoder_linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.encoder_linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.attention_query = nn.Linear(hidden_dim, hidden_dim)
        self.attention_key = nn.Linear(hidden_dim, hidden_dim)
        self.attention_value = nn.Linear(hidden_dim, hidden_dim)
        
        # Multiple independent decoder blocks (these will be functions)
        for i in range(num_decoder_blocks):
            setattr(self, f'decoder_block_{i}', DecoderBlock(hidden_dim))
        
        # Post-processing layers (inline, not separate modules)
        self.output_linear1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.output_linear2 = nn.Linear(hidden_dim // 2, output_dim)
        
        self.num_decoder_blocks = num_decoder_blocks
        self.activation = nn.ReLU()
        
    def forward(self, x):
        # Pre-processing (inline operations)
        x = self.input_projection(x)
        
        # Encoder-like processing
        x = self.activation(self.encoder_linear1(x))
        x = self.activation(self.encoder_linear2(x))
        
        # Attention-like processing
        q = self.attention_query(x)
        k = self.attention_key(x)
        v = self.attention_value(x)
        attention_weights = torch.softmax(torch.matmul(q, k.transpose(-2, -1)), dim=-1)
        x = torch.matmul(attention_weights, v)
        
        # Pass through each decoder block (these will be separate functions)
        for i in range(self.num_decoder_blocks):
            decoder_block = getattr(self, f'decoder_block_{i}')
            x = decoder_block(x)
        
        # Post-processing (inline operations)
        x = self.activation(self.output_linear1(x))
        output = self.output_linear2(x)
        
        return output

# Create model with 6 decoder blocks
model = MainModel(num_decoder_blocks=6)
sample_input = torch.rand((32, 8), dtype=torch.float32)

# Export to ONNX with modules as functions
torch.onnx.export(
    model,
    sample_input,
    "decoder_functions_model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    export_modules_as_functions=True,
    verbose=True
)
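
This is roughly what I mean by enabling dynamo (a minimal sketch; it assumes a recent PyTorch where torch.onnx.export accepts dynamo=True and returns an ONNXProgram, and the output file name is just for illustration):

# Sketch of the dynamo path (assumes torch.onnx.export supports dynamo=True).
onnx_program = torch.onnx.export(model, (sample_input,), dynamo=True)
onnx_program.save("decoder_dynamo_model.onnx")

import onnx

# In my runs the decoder blocks end up inlined in the top-level graph as
# decomposed aten/ONNX ops, so no local functions are emitted.
reloaded = onnx.load("decoder_dynamo_model.onnx")
print("local functions:", len(reloaded.functions))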

@ptrblck

The feature is no longer supported, but you could create your own graph transformation to obtain this result. What is your use case for needing the functions?

@justinchuby thanks for the response. How can we create such a graph transformation? Right now we are exporting specific modules as functions with:

export_modules_as_functions={MODULE_NAME}

The requirement is to have the repeated blocks share a single unique function proto signature, so that it’s easier to optimize the graph.
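
For reference, here is a trimmed sketch of that pattern with the legacy exporter (based on the documented signature, which also accepts a set of nn.Module subclasses; the output file name is just for illustration):

# Only instances of the listed module classes are wrapped as ONNX functions.
# The docs note this feature requires opset_version >= 15.
torch.onnx.export(
    model,
    sample_input,
    "decoder_blocks_as_functions.onnx",
    input_names=['input'],
    output_names=['output'],
    export_modules_as_functions={DecoderBlock},
)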

I don’t think the old exporter produces reused function definitions in the graph. IIRC it only wraps nn.Module-scoped ops into functions; it does not collapse shared logic.

With the new exporter, you may consider leveraging the embedded metadata to recreate the hierarchy: torch.export-based ONNX Exporter — PyTorch 2.9 documentation
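
Roughly, something like this (just a sketch, not authoritative: the exact metadata keys depend on the PyTorch and onnx versions, onnx >= 1.16 is assumed for NodeProto.metadata_props, and the file name comes from the earlier dynamo sketch):

import onnx

# Dump whatever per-node metadata the dynamo exporter recorded, then group
# nodes by module scope; print the keys first and adapt to your version.
model_proto = onnx.load("decoder_dynamo_model.onnx")

scopes = {}
for node in model_proto.graph.node:
    props = {p.key: p.value for p in node.metadata_props}
    # Pick out any scope/hierarchy-looking entry, if one is present.
    scope = next((v for k, v in props.items() if "scope" in k or "hierarchy" in k), "<none>")
    scopes.setdefault(scope, []).append(node.op_type)

for scope, ops in scopes.items():
    print(scope, "->", ops)

# Nodes that share a decoder-block scope could then be outlined into a single
# FunctionProto by a custom graph pass, which is the transformation I suggested.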