CUDA Graph Error with Residual Connections in `torch.compile` (RuntimeError: accessing tensor output of CUDAGraphs)

Hey,
I’m getting a persistent `RuntimeError: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run` when I compile my transformer model with torch.compile using mode="max-autotune". The error seems to be related to the residual connections in my Block module, and it persists even after I tried the common fixes listed below.

Environment:

  • PyTorch version: 2.7.1+cu126
  • CUDA version: 12.9
  • GPU: NVIDIA GeForce GTX 1650 Mobile / Max-Q
  • Operating System: Linux Mint 22.1 x86_64
  • Python version: 3.12
  • bitsandbytes is used for 4-bit linear layers.

Model Architecture Snippets (simplified for relevance):

class Block(nn.Module):
    def __init__(self, d_model: int, n_heads: int, ffn_hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.attention_norm = nn.RMSNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, n_heads, block_size, dropout)  # block_size is a module-level global
        self.ffn_norm = nn.RMSNorm(d_model)
        self.ffn = FeedForward(d_model, ffn_hidden_dim, dropout)
        self.residual_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First residual connection
        x = x + self.residual_dropout(self.self_attn(self.attention_norm(x)))
        # Second residual connection
        x = x + self.residual_dropout(self.ffn(self.ffn_norm(x)))
        return x

class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        ffn_hidden_dim = int(4 * n_embd * (2/3))
        self.blocks = nn.ModuleList([
            Block(d_model=n_embd, n_heads=n_head, ffn_hidden_dim=ffn_hidden_dim, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.token_embedding_table.weight = self.lm_head.weight

    def forward(self, idx, targets=None, smoothing=0.0):
        x = self.token_embedding_table(idx)
        for block in self.blocks:
            x = block(x) # This loop is likely where the issue originates within the graph
        x = self.ln_f(x)
        logits = self.lm_head(x)
        # ... loss calculation ...
        return logits, loss

Training Loop Snippet:

# ... (inside main function) ...
model = LanguageModel().to(device)
model = torch.compile(model, mode="max-autotune") # Compilation happens here

optimizer.zero_grad(set_to_none=True)

while iter_num < max_iters:
    try:
        xb, yb = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        xb, yb = next(train_iter)

    xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)

    torch.compiler.cudagraph_mark_step_begin()

    with autocast(device_type=device, dtype=pt_dtype):
        _, loss = model(xb, yb, smoothing=label_smoothing)
        loss = loss / gradient_accumulation_steps

    scaler.scale(loss).backward()

Problem Description:

I get the following error every time I use mode="max-autotune" or mode="reduce-overhead":

RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace:
  File "/home/lixt/Desktop/ㅤ/projects/AI/CHAT/train.py", line 246, in forward
    norm_x = self.attention_norm(x)
  File "/home/lixt/.local/lib/python3.12/site-packages/torch/nn/functional.py", line 2929, in rms_norm
    return torch.rms_norm(input, normalized_shape, weight, eps). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

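The stack trace points to the RMSNorm call (attention_norm) inside Block.forward; in the full train.py that line reads norm_x = self.attention_norm(x).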
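The error message describes a specific situation: a CUDA graph replays into the same static memory every time, so a tensor returned by one run is reused by the next, and reading the old tensor afterwards is an error unless it was copied out first. The pure-Python toy below mimics only that idea. The classes are hypothetical and are not PyTorch's cudagraph-trees implementation. Each replay overwrites a single preallocated buffer, a handle from an earlier replay raises when read, and copying the data out before the next replay (the "clone" the message suggests) keeps it usable.

```python
class StaticBufferGraph:
    """Toy model of a replayed graph that always writes into one static buffer."""

    def __init__(self, fn, size):
        self.fn = fn
        self.out = [0.0] * size  # static output buffer, allocated once
        self.generation = 0      # incremented on every replay

    def replay(self, inputs):
        self.generation += 1
        for i, v in enumerate(inputs):
            self.out[i] = self.fn(v)  # overwrite the same buffer in place
        return OutputView(self, self.generation)


class OutputView:
    """Handle to the static buffer that remembers which replay produced it."""

    def __init__(self, graph, generation):
        self.graph = graph
        self.generation = generation

    def read(self):
        if self.generation != self.graph.generation:
            raise RuntimeError(
                "accessing output that has been overwritten by a subsequent run"
            )
        return list(self.graph.out)

    def clone(self):
        # Copy the data out while this handle's generation is still current.
        return self.read()
```

With `g = StaticBufferGraph(lambda v: v * 2, size=3)`, taking `first = g.replay([1, 2, 3])` and `saved = first.clone()` gives an independent `[2, 4, 6]`. After `g.replay([10, 20, 30])`, `saved` is unaffected but `first.read()` raises, which is the shape of the failure in the traceback.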

What I’ve tried (none of it fixed the error):

  1. Refactoring Block.forward: changed h = x + … to x = x + … to make the data flow explicit and rule out in-place issues within the compiled graph.

  2. torch.compiler.cudagraph_mark_step_begin(): called at the start of every training step, before the model call, which is the standard practice for CUDA graphs with torch.compile.

  3. Adding .clone(): I tried calling .clone() on the block's input x and before the residual addition. This did not resolve the issue and often caused other graph-related errors or slowed training down.
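On items 1 and 3: in PyTorch, `h = x + f(x)` and `x = x + f(x)` are both out-of-place additions that allocate a new tensor (only the Python name being bound differs), whereas `x += f(x)` writes into x's existing storage. Likewise, `.clone()` always produces a copy with its own storage. A small CPU check of that behavior, with hypothetical helper names:

```python
import torch


def residual_out_of_place(x, branch):
    # x = x + f(x): allocates a new tensor; the caller's x is untouched.
    return x + branch(x)


def residual_in_place(x, branch):
    # x += f(x): writes into x's existing storage, mutating the caller's tensor.
    x += branch(x)
    return x


def snapshot(t):
    # detach().clone(): a private copy with its own storage, outside autograd.
    return t.detach().clone()
```

This is consistent with the refactor in item 1 not changing anything: neither version of the residual add was in-place to begin with, so the storage being reported as overwritten is not coming from the `+` itself.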

The error persists even after restructuring Block.forward to avoid explicit intermediate variables and calling cudagraph_mark_step_begin() every step. I’d like to understand why torch.compile still detects an overwrite inside the graph, especially since the stack trace points at the input to a normalization layer, and how I can fix it. Thank you.