Performance bottleneck with autocast in AOT Inductor export: repeated weight casting due to disabled cache

Problem Description

When I export my model with AOT Inductor, profiling shows that the weights inside autocast regions are re-cast to the target precision (fp16/bf16) on every forward pass instead of once, which creates a significant performance bottleneck at inference time.
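For context, the repeated casting is easy to reproduce in plain eager mode by disabling the cache explicitly. This is a small stand-in experiment (CPU/bf16 instead of CUDA/fp16 so it runs anywhere; counting `aten::to`/`aten::_to_copy` profiler events is my own heuristic for "cast ops"):

```python
import torch
from torch.profiler import profile, ProfilerActivity

lin = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

def count_casts(cache_enabled):
    # Profile three forward passes under autocast and count dtype-cast ops.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16,
                        cache_enabled=cache_enabled):
        with profile(activities=[ProfilerActivity.CPU]) as prof:
            for _ in range(3):
                lin(x)
    return sum(e.count for e in prof.key_averages()
               if e.key in ("aten::to", "aten::_to_copy"))

# With the cache disabled (as during AOT compilation), every pass re-casts
# the weight and bias; with it enabled, only the activation is re-cast.
print(count_casts(cache_enabled=False), count_casts(cache_enabled=True))
```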

Root Cause Analysis

After investigating the source code, I found that:

  1. In normal eager mode, torch.autocast has cache_enabled=True by default (torch/amp/autocast_mode.py:214-215), which caches weight conversions and avoids redundant casting.

  2. However, during AOT compilation, the cache is explicitly disabled in torch/_dynamo/compiled_autograd.py:270:

    self.stack.enter_context(disable_autocast_cache())
    
    
  3. When the model is exported via torch.export.export(), autocast regions are converted to wrap_with_autocast Higher Order Operators (torch/_export/passes/replace_autocast_with_hop_pass.py), but the cache remains disabled in the exported graph.
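The effect of that disabled cache can be verified directly: `torch.is_autocast_cache_enabled()` reports the flag, and passing `cache_enabled=False` to the context manager (what the compiled-autograd path effectively does via `disable_autocast_cache()`) flips it for the region. A CPU/bf16 stand-in, since the flag itself is device-agnostic:

```python
import torch

# The process-wide default matches eager behavior: caching is on.
assert torch.is_autocast_cache_enabled()

# cache_enabled=False reproduces the state inside the compiled region.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16,
                    cache_enabled=False):
    assert not torch.is_autocast_cache_enabled()

# The previous state is restored on exit.
assert torch.is_autocast_cache_enabled()
```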

Minimal Reproducible Example

import torch
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 128)
    
    def forward(self, x):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            x = self.linear1(x)
            x = torch.relu(x)
            x = self.linear2(x)
        return x

class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = SubModule()
        self.output = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.sub(x)
        x = self.output(x)
        return x

# Export, then compile ahead-of-time with Inductor
model = MainModel().cuda()
example_input = torch.randn(4, 128, device="cuda")
exported_program = torch.export.export(model, (example_input,))

# AOT Inductor packaging step (requires a recent PyTorch)
package_path = torch._inductor.aoti_compile_and_package(exported_program)

# Profiling the compiled model shows repeated weight-cast ops on every forward pass

Questions

  1. Is there a way to enable autocast caching in AOT-compiled models to avoid repeated weight casting during inference?

  2. What’s the recommended approach to achieve the performance benefits of cached weight conversion in exported models? Should I:

    • Pre-convert weights to the target dtype before export?

    • Use a custom export pass to optimize autocast regions?

    • Use a different export/compilation strategy?
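For reference, the pre-conversion idea from the first bullet would look roughly like this (a sketch only, not a confirmed recommendation; it rewrites `SubModule` from the repro to hold fp16 weights via a one-time `Module.to` and drops the autocast region entirely):

```python
import torch
import torch.nn as nn

class SubModule(nn.Module):
    """Same submodule as in the repro, minus the autocast context."""
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 128)

    def forward(self, x):
        # Cast the activation once at the region boundary; the weights are
        # already fp16, so no per-call weight casting should be needed.
        x = x.to(self.linear1.weight.dtype)
        x = self.linear2(torch.relu(self.linear1(x)))
        return x.float()

# One-time weight conversion before torch.export.export is called.
sub = SubModule().to(torch.float16)
assert sub.linear1.weight.dtype == torch.float16
```

The trade-off I see is that this bakes the dtype into the parameters rather than leaving it to autocast's per-op policy, so ops that autocast would have kept in fp32 lose that protection.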