Hi, I’m trying to export a PyTorch model to ONNX and generate an ONNX model that contains functions. I found that torch is deprecating export_modules_as_functions (TorchDynamo-based ONNX Exporter — PyTorch 2.8 documentation). Does that mean this behavior will be enabled by default when we pass dynamo=True? When I tried enabling dynamo, I did not see the functions; instead the modules were decomposed into aten ops.
Below is the code I’m using to export modules as functions, which works. But when I enable dynamo, everything is decomposed into aten ops. Is there a way to keep the modules as ONNX functions when dynamo is enabled, or are there other things to consider or look for?
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        residual = x
        x = self.activation(self.linear1(x))
        x = self.linear2(x)
        x = self.layer_norm(x + residual)
        return x
class MainModel(nn.Module):
    def __init__(self, input_dim=8, hidden_dim=16, output_dim=4, num_decoder_blocks=6):
        super().__init__()
        # Pre-processing layers (inline, not separate modules)
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.encoder_linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.encoder_linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.attention_query = nn.Linear(hidden_dim, hidden_dim)
        self.attention_key = nn.Linear(hidden_dim, hidden_dim)
        self.attention_value = nn.Linear(hidden_dim, hidden_dim)
        # Multiple independent decoder blocks (these will be functions)
        for i in range(num_decoder_blocks):
            setattr(self, f'decoder_block_{i}', DecoderBlock(hidden_dim))
        # Post-processing layers (inline, not separate modules)
        self.output_linear1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.output_linear2 = nn.Linear(hidden_dim // 2, output_dim)
        self.num_decoder_blocks = num_decoder_blocks
        self.activation = nn.ReLU()

    def forward(self, x):
        # Pre-processing (inline operations)
        x = self.input_projection(x)
        # Encoder-like processing
        x = self.activation(self.encoder_linear1(x))
        x = self.activation(self.encoder_linear2(x))
        # Attention-like processing
        q = self.attention_query(x)
        k = self.attention_key(x)
        v = self.attention_value(x)
        attention_weights = torch.softmax(torch.matmul(q, k.transpose(-2, -1)), dim=-1)
        x = torch.matmul(attention_weights, v)
        # Pass through each decoder block (these will be separate functions)
        for i in range(self.num_decoder_blocks):
            decoder_block = getattr(self, f'decoder_block_{i}')
            x = decoder_block(x)
        # Post-processing (inline operations)
        x = self.activation(self.output_linear1(x))
        output = self.output_linear2(x)
        return output
# Create model with 6 decoder blocks
model = MainModel(num_decoder_blocks=6)
sample_input = torch.rand((32, 8), dtype=torch.float32)

# Export to ONNX with modules as functions
torch.onnx.export(
    model,
    sample_input,
    "decoder_functions_model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    export_modules_as_functions=True,
    verbose=True
)