## Problem Description
When I export my model with AOT Inductor, I notice that weights are being repeatedly cast to the target precision (fp16/bf16) on every forward pass, which creates a significant performance bottleneck.
## Root Cause Analysis
After investigating the source code, I found that:

- In normal eager mode, `torch.autocast` has `cache_enabled=True` by default (`torch/amp/autocast_mode.py:214-215`), which caches weight conversions and avoids redundant casting.
- However, during AOT compilation, the cache is explicitly disabled in `torch/_dynamo/compiled_autograd.py:270`: `self.stack.enter_context(disable_autocast_cache())`
- When the model is exported via `torch.export.export()`, autocast regions are converted to `wrap_with_autocast` Higher Order Operators (`torch/_export/passes/replace_autocast_with_hop_pass.py`), but the cache remains disabled in the exported graph.
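The eager-mode caching behavior described above can be observed through the public API. A minimal sketch (using CPU autocast with bfloat16 so it runs without a GPU; the underlying mechanism is the same for CUDA fp16):

```python
import torch

# Eager mode: the autocast weight-cast cache is on by default, so repeated
# uses of the same fp32 weight inside one autocast region reuse a single
# low-precision copy instead of re-casting on every call.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    print(torch.is_autocast_cache_enabled())  # True

# The AOT compilation path disables this cache internally; the same state
# can be reproduced with the public cache_enabled flag:
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=False):
    print(torch.is_autocast_cache_enabled())  # False
```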
## Minimal Reproducible Example
```python
import torch
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 128)

    def forward(self, x):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            x = self.linear1(x)
            x = torch.relu(x)
            x = self.linear2(x)
        return x

class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = SubModule()
        self.output = nn.Linear(128, 10)

    def forward(self, x):
        x = self.sub(x)
        x = self.output(x)
        return x

# Export with AOT Inductor
model = MainModel().cuda()
example_input = torch.randn(4, 128, device="cuda")
exported_program = torch.export.export(model, (example_input,))
# The exported model shows repeated casting operations in profiling
```
## Questions
- Is there a way to enable autocast caching in AOT-compiled models to avoid repeated weight casting during inference?
- What's the recommended approach to achieve the performance benefits of cached weight conversion in exported models? Should I:
  - Pre-convert weights to the target dtype before export?
  - Use a custom export pass to optimize autocast regions?
  - Use a different export/compilation strategy?
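For context, the first option can be sketched as follows. This is a hedged example on a hypothetical stand-in model, not the repro above, and it assumes the whole network tolerates running end-to-end in the reduced precision (which removes the need for autocast, and hence for its cache, entirely). It uses bfloat16 on CPU so it runs anywhere:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; pre-converting the weights to the target
# dtype before export means the exported graph contains no runtime casts.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))
model = model.to(torch.bfloat16).eval()

example_input = torch.randn(4, 128, dtype=torch.bfloat16)
with torch.no_grad():
    out = model(example_input)
print(out.dtype)  # torch.bfloat16; no weight casting at inference time
```

The pre-converted model could then be passed to `torch.export.export(model, (example_input,))` as in the repro, with the autocast block removed from `forward`.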