Hey,
I’m encountering a persistent `RuntimeError: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run` when using `torch.compile` (specifically with `mode="max-autotune"`) on my transformer model. The error appears to be related to the residual connections in my `Block` module, and it persists even after I applied the common fixes.
Environment:
- PyTorch version: 2.7.1+cu126
- CUDA version: 12.9
- GPU: NVIDIA GeForce GTX 1650 Mobile / Max-Q
- Operating System: Linux Mint 22.1 x86_64
- Python version: 3.12
- `bitsandbytes` is used for 4-bit linear layers.
Model Architecture Snippets (simplified for relevance):
```python
import torch
import torch.nn as nn

# MultiHeadAttention, FeedForward, block_size and the other hyperparameters
# (vocab_size, n_embd, n_head, n_layer, dropout) are defined elsewhere in my code.

class Block(nn.Module):
    def __init__(self, d_model: int, n_heads: int, ffn_hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.attention_norm = nn.RMSNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, n_heads, block_size, dropout)
        self.ffn_norm = nn.RMSNorm(d_model)
        self.ffn = FeedForward(d_model, ffn_hidden_dim, dropout)
        self.residual_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First residual connection (pre-norm attention)
        x = x + self.residual_dropout(self.self_attn(self.attention_norm(x)))
        # Second residual connection (pre-norm feed-forward)
        x = x + self.residual_dropout(self.ffn(self.ffn_norm(x)))
        return x


class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        ffn_hidden_dim = int(4 * n_embd * (2/3))
        self.blocks = nn.ModuleList([
            Block(d_model=n_embd, n_heads=n_head, ffn_hidden_dim=ffn_hidden_dim, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.token_embedding_table.weight = self.lm_head.weight  # weight tying

    def forward(self, idx, targets=None, smoothing=0.0):
        x = self.token_embedding_table(idx)
        for block in self.blocks:
            x = block(x)  # This loop is likely where the issue originates within the graph
        x = self.ln_f(x)
        logits = self.lm_head(x)
        # ... loss calculation ...
        return logits, loss
```
Training Loop Snippet:
```python
# ... (inside main function) ...
model = LanguageModel().to(device)
model = torch.compile(model, mode="max-autotune")  # Compilation happens here
optimizer.zero_grad(set_to_none=True)
while iter_num < max_iters:
    try:
        xb, yb = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        xb, yb = next(train_iter)
    xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
    torch.compiler.cudagraph_mark_step_begin()
    with autocast(device_type=device, dtype=pt_dtype):
        _, loss = model(xb, yb, smoothing=label_smoothing)
        loss = loss / gradient_accumulation_steps
    scaler.scale(loss).backward()
```
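A minimal sketch of the pattern the error message describes, under these assumptions: a hypothetical `ToyModel` that, like `LanguageModel`, returns `(logits, loss)`; `torch.compiler.cudagraph_mark_step_begin()` called before every compiled call, including evaluation calls; and only Python floats (via `.item()`) kept across iterations. The AMP scaler and autocast are omitted for brevity, and compilation is only enabled when `device == "cuda"`. The helper names (`maybe_compile`, `mark_step`, `train_steps`, `evaluate`) are illustrative, not from the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    """Hypothetical stand-in that, like LanguageModel, returns (logits, loss)."""

    def __init__(self, d_in: int = 8, n_classes: int = 3):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_in, 16), nn.ReLU(), nn.Linear(16, n_classes))

    def forward(self, x, targets=None):
        logits = self.net(x)
        loss = F.cross_entropy(logits, targets) if targets is not None else None
        return logits, loss


def maybe_compile(model: nn.Module, device: str) -> nn.Module:
    # CUDA graphs (and therefore this RuntimeError) only apply to CUDA tensors;
    # on other devices the sketch runs the model eagerly.
    if device == "cuda":
        return torch.compile(model, mode="reduce-overhead")
    return model


def mark_step(device: str) -> None:
    # Signal that a new iteration begins, so outputs of earlier compiled calls
    # may be overwritten from here on.
    if device == "cuda":
        torch.compiler.cudagraph_mark_step_begin()


def train_steps(model, optimizer, batches, accum_steps: int, device: str) -> list[float]:
    history: list[float] = []
    optimizer.zero_grad(set_to_none=True)
    for i, (xb, yb) in enumerate(batches):
        xb, yb = xb.to(device), yb.to(device)
        mark_step(device)  # before *every* compiled call
        _, loss = model(xb, yb)
        loss = loss / accum_steps
        loss.backward()
        # Keep a Python float, not the tensor: a tensor returned by the compiled
        # model may live in memory that the next graph replay overwrites.
        history.append(loss.item() * accum_steps)
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
    return history


@torch.no_grad()
def evaluate(model, batches, device: str) -> float:
    total = 0.0
    for xb, yb in batches:
        xb, yb = xb.to(device), yb.to(device)
        mark_step(device)  # evaluation calls are model invocations too
        _, loss = model(xb, yb)
        total += loss.item()
    return total / len(batches)
```

The main design choice is that nothing tensor-valued produced by the compiled model survives past the next `mark_step`; logging, running averages and evaluation results are reduced to floats first.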
Problem Description:
I’m encountering the following error every time I use `mode="max-autotune"` or `mode="reduce-overhead"`:

```
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/home/lixt/Desktop/ㅤ/projects/AI/CHAT/train.py", line 246, in forward
norm_x = self.attention_norm(x)
File "/home/lixt/.local/lib/python3.12/site-packages/torch/nn/functional.py", line 2929, in rms_norm
return torch.rms_norm(input, normalized_shape, weight, eps). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
```
The stack trace points to the `RMSNorm` call (`self.attention_norm(x)`) inside `Block.forward`.
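Both `max-autotune` and `reduce-overhead` enable CUDA graphs, so one way to narrow this down is to compile with `mode="max-autotune-no-cudagraphs"`, which keeps Inductor's autotuning but skips CUDA graphs. A minimal sketch (the helper name `compile_model` and its `use_cudagraphs` flag are hypothetical):

```python
import torch
import torch.nn as nn


def compile_model(model: nn.Module, use_cudagraphs: bool, device: str) -> nn.Module:
    """Diagnostic sketch: compile with max-autotune, optionally without CUDA graphs."""
    if device != "cuda":
        # CUDA graphs are only used for CUDA tensors; run eagerly elsewhere.
        return model
    # "max-autotune-no-cudagraphs" keeps autotuning but does not capture CUDA graphs,
    # so outputs are ordinary allocations and the overwrite check does not apply.
    mode = "max-autotune" if use_cudagraphs else "max-autotune-no-cudagraphs"
    return torch.compile(model, mode=mode)
```

If the error disappears with `use_cudagraphs=False`, that would suggest a tensor from one compiled call is being read after a later call replayed the graph, rather than a problem specific to `RMSNorm` or the residual additions.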
What I’ve tried (and did not fix it):
- Refactoring `Block.forward`: changed `h = x + …` to `x = x + …` to make the data flow explicit and avoid potential in-place issues within the compiled graph.
- `torch.compiler.cudagraph_mark_step_begin()`: called it at the beginning of each training step, before the model call, which is standard practice for CUDA graphs with `torch.compile`.
- Adding `.clone()`: I initially tried calling `.clone()` on the block input `x` or before the residual addition, but this did not resolve the issue and often caused other graph-related errors or degraded performance.
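The error message asks for the clone to happen *outside* `torch.compile()`, that is, on a returned tensor that has to survive the next call. A `.clone()` on `x` inside `Block.forward` is itself part of the compiled graph. A minimal sketch of the outside-the-graph version, using a hypothetical toy function (`make_step`, `run_twice`) in place of the model:

```python
import torch


def make_step(device: str):
    def step(x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, x)

    if device == "cuda":
        # reduce-overhead uses CUDA graphs, whose outputs are reused across replays.
        step = torch.compile(step, mode="reduce-overhead")
    return step


def run_twice(device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    step = make_step(device)
    x = torch.randn(10, 10, device=device)
    # Cloning the *returned* tensor copies it out of the CUDA graph's output
    # memory, so the second call below cannot overwrite y1.
    y1 = step(x).clone()
    y2 = step(x)
    return y1, y2
```

Without the `.clone()`, reading `y1` after the second call is the access pattern this `RuntimeError` is meant to catch.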
Given that the error persists even after restructuring `Block.forward` to avoid explicit intermediate variables and calling `cudagraph_mark_step_begin()`, I’m looking for help understanding why `torch.compile` still detects an overwrite in the graph, especially since the stack trace points to a normalization layer, and how I could fix it. Thank you.