Hi, this script throws an error when running evaluate_meta_adapt. I have included the entire script so you can see how the evaluation function is supposed to match the training logic, and specifically the training gradient flow. I am using Grok to help me code this fairly complicated model. Grok has hit a wall and keeps producing this error, and this version is the closest to correct I can get it to produce. The code is below, along with Grok's copy-pasted explanation of what it says needs to be addressed. To reproduce the error yourself, run the script for one epoch and then run the evaluation, which will fail.
No, the evaluate_meta_adapt function is not flowing gradients in exactly the same way as training. Here's a quick analysis:
- Training Flow: For each batch, it runs N cycles with T inner logit GD steps each (with create_graph=True for higher-order grads), updates the state per cycle, then computes a single outer_loss on the final adapted_logits, backprops once, and updates the model params once per batch. Gradients flow through the entire N x T chain in one backprop.
- Eval Flow: It runs N cycles with T inner steps, but computes a cycle_loss per cycle, backprops per cycle (multiple backward calls), and updates params per cycle. This is different: multiple backprops per batch vs. one in training. The error ("RuntimeError: Trying to backward through the graph a second time") occurs because backward() frees the graph after the first call without retain_graph=True. Also, since it's on the same model (no clone), it modifies the trained weights, which is risky for eval.
- To Match Training: Eval should run the N x T loop, compute a single loss at the end, and backprop once. But to adapt params during eval (as requested), add retain_graph=True for the multi-cycle backprop, and use a clone to avoid changing the original model.

Updated code with retain_graph=True, clone, and train mode on the clone for adaptation.
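For context, here is my own rough sketch (not Grok's actual code) of what I understand the "clone the model and backprop once at the end" idea above to mean. As far as I can tell, the evaluate_meta_adapt in the script below still builds the optimizer on model.parameters() and backprops once per cycle, so this is only a guess at the matched version. The names adapt_on_clone and adaptation_loss_fn are placeholders I made up; adaptation_loss_fn is meant to be the adaptation_loss defined in the script.

import copy

import torch
import torch.nn.functional as F
import torch.optim as optim


def adapt_on_clone(model, puzzles, givens, adaptation_loss_fn, N=3, T=5, inner_lr=0.1, meta_lr=0.001):
    adapt_model = copy.deepcopy(model)  # leave the trained weights untouched
    adapt_model.train()                 # keep dropout/grad behavior like training
    meta_opt = optim.Adam(adapt_model.parameters(), lr=meta_lr)

    current_state = puzzles.clone()
    for n in range(N):                  # same N x T structure as the training loop
        delta_logits = adapt_model(current_state)
        adapted_logits = delta_logits.clone().requires_grad_(True)
        for t in range(T):              # inner logit GD steps, graph kept for higher-order grads
            inner_l = adaptation_loss_fn(adapted_logits, givens)
            grad = torch.autograd.grad(inner_l, adapted_logits, create_graph=True)[0]
            adapted_logits = adapted_logits - inner_lr * grad
        # Soft-merge adapted probs into the state, as in training
        probs = F.softmax(adapted_logits, dim=-1).permute(0, 3, 1, 2)
        mask = current_state.sum(dim=1, keepdim=True) == 0
        current_state = current_state + mask.float() * probs

    # Single self-supervised loss at the end, single backward, single param update,
    # mirroring the one outer backprop per batch used in training.
    meta_loss = adaptation_loss_fn(adapted_logits, givens)
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()
    return adapt_model, adapted_logits

The idea, as I understand it, is that all N x T steps stay in one graph, a single self-supervised loss is backpropagated once, and only the deep-copied clone's weights are updated, so there is no retain_graph juggling and the trained model is never modified.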
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import copy
import unittest
import sys
import io # For capturing test output
from torch.utils.data import Dataset, DataLoader
import higher # For meta-gradient adaptation on params
import wandb # For monitoring
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import multiprocessing as mp
from itertools import product
import json # For saving results
import os # For file paths
# --- Device Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# --- Sudoku Generator Functions ---
def generate_solved_board():
    """Generate a solved 9x9 Sudoku board using randomized backtracking."""
    board = [[0] * 9 for _ in range(9)]

    def is_valid(board, row, col, num):
        for i in range(9):
            if board[row][i] == num or board[i][col] == num:
                return False
        box_row, box_col = 3 * (row // 3), 3 * (col // 3)
        for i in range(3):
            for j in range(3):
                if board[box_row + i][box_col + j] == num:
                    return False
        return True

    def solve(board):
        for row in range(9):
            for col in range(9):
                if board[row][col] == 0:
                    nums = list(range(1, 10))
                    random.shuffle(nums)
                    for num in nums:
                        if is_valid(board, row, col, num):
                            board[row][col] = num
                            if solve(board):
                                return True
                            board[row][col] = 0
                    return False
        return True

    solve(board)
    return board

def create_puzzle(board, num_clues=30):
    puzzle = copy.deepcopy(board)
    cells = [(i, j) for i in range(9) for j in range(9)]
    random.shuffle(cells)
    for i in range(81 - num_clues):
        row, col = cells[i]
        puzzle[row][col] = 0
    return puzzle, board
class SudokuDataset(Dataset):
    def __init__(self, num_samples):
        self.data = []
        for _ in range(num_samples):
            solved = generate_solved_board()
            puzzle, solution = create_puzzle(solved)
            # One-hot input: 9 channels for digits 1-9
            puzzle_onehot = F.one_hot(torch.tensor(puzzle).long(), num_classes=10)[:, :, 1:].permute(2, 0, 1).float() # (9,9,9)
            solution_tensor = torch.tensor(solution, dtype=torch.long) # (9,9)
            given_tensor = torch.tensor(puzzle, dtype=torch.long) # For fixed cells
            self.data.append((puzzle_onehot, solution_tensor, given_tensor))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    puzzles = torch.stack([item[0] for item in batch]) # (B,9,9,9)
    solutions = torch.stack([item[1] for item in batch]) # (B,9,9)
    givens = torch.stack([item[2] for item in batch]) # (B,9,9)
    return puzzles, solutions, givens
# --- High-Level Model (Initializer for Logits) ---
class HighLevelCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(9, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(256 * 9 * 9, 9 * 9 * 9)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = self.dropout(out.view(out.size(0), -1))
        logits = self.fc(out).view(out.size(0), 9, 9, 9)
        return logits
# --- Adaptation Loss (Inner: Constraints + Fixed CE) ---
def adaptation_loss(logits, given):
    p = F.softmax(logits, dim=-1) # (B,9,9,9)
    # Constraints (scaled by 5)
    row_loss = ((p.sum(dim=2) - 1) ** 2).mean()
    col_loss = ((p.sum(dim=1) - 1) ** 2).mean()
    box_p = p.view(-1, 3, 3, 3, 3, 9).sum(dim=(2,4)) - 1
    box_loss = (box_p ** 2).mean()
    const_loss = 5 * (row_loss + col_loss + box_loss) # Increased weight
    # Fixed cells CE
    fixed_loss = 0.0
    B = logits.size(0)
    for b in range(B):
        mask = given[b] > 0
        if mask.sum() > 0:
            fixed_logits = logits[b][mask].view(-1, 9)
            fixed_targets = (given[b][mask] - 1).long().view(-1)
            fixed_loss += F.cross_entropy(fixed_logits, fixed_targets)
    fixed_loss /= B if B > 0 else 1
    return const_loss + fixed_loss
# --- Outer Loss (Full CE + Constraints) ---
def outer_loss(logits, solutions):
    flat_logits = logits.view(-1, 9)
    flat_targets = (solutions.view(-1) - 1).long()
    ce_loss = F.cross_entropy(flat_logits, flat_targets)
    p = F.softmax(logits, dim=-1)
    row_loss = ((p.sum(dim=2) - 1) ** 2).mean()
    col_loss = ((p.sum(dim=1) - 1) ** 2).mean()
    box_p = p.view(-1, 3, 3, 3, 3, 9).sum(dim=(2,4)) - 1
    box_loss = (box_p ** 2).mean()
    const_loss = 5 * (row_loss + col_loss + box_loss) # Increased weight
    return ce_loss + const_loss
# --- Training Function with N x T Structure ---
def train(model, dataloader, epochs=100, outer_lr=0.001, inner_lr=0.1, N=3, T=5):
    model.to(device)
    outer_opt = optim.Adam(model.parameters(), lr=outer_lr)
    for epoch in range(epochs):
        total_outer_loss = 0.0
        for idx, (puzzles, solutions, givens) in enumerate(dataloader):
            puzzles, solutions, givens = puzzles.to(device), solutions.to(device), givens.to(device)
            print(f"Epoch {epoch+1}/{epochs} - Processing batch {idx+1}/{len(dataloader)}")
            current_state = puzzles.clone() # Start with initial puzzle state
            for n in range(N): # Outer high-level cycles
                # High-level: Generate delta logits based on current state
                delta_logits = model(current_state)
                adapted_logits = delta_logits.clone().requires_grad_(True)
                for t in range(T): # Inner low-level adaptations
                    inner_l = adaptation_loss(adapted_logits, givens)
                    grad = torch.autograd.grad(inner_l, adapted_logits, create_graph=True)[0]
                    adapted_logits = adapted_logits - inner_lr * grad
                    print(f" Cycle {n+1}/{N}, Inner step {t+1}/{T}: Adapt loss = {inner_l.item():.6f}")
                # Update state for next cycle: Soft-merge adapted probs into state
                probs = F.softmax(adapted_logits, dim=-1).permute(0, 3, 1, 2) # (B,9,9,9) -> (B,9,9,9)
                mask = current_state.sum(dim=1, keepdim=True) == 0 # Update only remaining empties
                current_state = current_state + mask.float() * probs # Soft fill
            # Final outer loss after all cycles
            final_logits = adapted_logits # From last inner
            o_loss = outer_loss(final_logits, solutions)
            total_outer_loss += o_loss.item() * puzzles.size(0)
            outer_opt.zero_grad()
            o_loss.backward()
            outer_opt.step()
            print(f" Outer loss: {o_loss.item():.6f}")
        avg_outer_loss = total_outer_loss / len(dataloader.dataset)
        wandb.log({"avg_outer_loss": avg_outer_loss})
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1} completed. Avg outer loss: {avg_outer_loss:.6f}")
# --- New Evaluation with Backprop Through Meta-Gradients (Test-Time Adaptation on Params) ---
def evaluate_meta_adapt(model, dataloader, N=3, T=5, meta_lr=0.001, inner_lr=0.1):
    correct_puzzles = 0
    correct_cells = 0
    total_puzzles = 0
    total_cells = 0
    for idx, (puzzles, solutions, givens) in enumerate(dataloader):
        puzzles, solutions, givens = puzzles.to(device), solutions.to(device), givens.to(device)
        print(f"Evaluating batch {idx+1}/{len(dataloader)}")
        meta_opt = optim.Adam(model.parameters(), lr=meta_lr)
        current_state = puzzles.clone()
        for n in range(N):
            # High-level: Generate delta logits based on current state
            delta_logits = model(current_state)
            adapted_logits = delta_logits.clone().requires_grad_(True)
            for t in range(T):
                inner_l = adaptation_loss(adapted_logits, givens)
                grad = torch.autograd.grad(inner_l, adapted_logits, create_graph=True, retain_graph=True)[0]
                adapted_logits = adapted_logits - inner_lr * grad
                print(f" Cycle {n+1}/{N}, Inner step {t+1}/{T}: Adapt loss = {inner_l.item():.6f}")
            # Backprop through the cycle for meta-update
            cycle_loss = adaptation_loss(adapted_logits, givens) # Use adaptation_loss as meta loss (self-supervised)
            meta_opt.zero_grad()
            cycle_loss.backward(retain_graph=True)
            meta_opt.step()
            # Update state for next cycle
            probs = F.softmax(adapted_logits, dim=-1).permute(0, 3, 1, 2) # (B,9,9,9) -> (B,9,9,9)
            mask = current_state.sum(dim=1, keepdim=True) == 0 # Update only remaining empties
            current_state = current_state + mask.float() * probs # Soft fill
        final_logits = model(current_state) # Final pass after all adaptations
        pred = torch.argmax(final_logits, dim=-1) + 1 # (B,9,9)
        is_correct_puzzle = torch.all(pred == solutions, dim=(1,2))
        correct_puzzles += is_correct_puzzle.sum().item()
        correct_cells += (pred == solutions).sum().item()
        total_cells += pred.numel()
        total_puzzles += puzzles.size(0)
        print(f" Batch solve rate: {is_correct_puzzle.sum().item() / puzzles.size(0):.2%} | Sample pred row 0 (first sample): {pred[0][0].tolist()}")
    puzzle_accuracy = correct_puzzles / total_puzzles
    cell_accuracy = correct_cells / total_cells
    print(f"Evaluation complete. Solve accuracy: {puzzle_accuracy:.2%} | Per-cell accuracy: {cell_accuracy:.2%}")
    model.train()
# --- Unit Tests ---
class TestSudokuModel(unittest.TestCase):
    def test_generator(self):
        board = generate_solved_board()
        self.assertEqual(len(board), 9)
        self.assertEqual(len(board[0]), 9)
        self.assertFalse(any(0 in row for row in board))
        for row in board:
            self.assertEqual(sum(row), sum(range(1,10)))

    def test_puzzle_creation(self):
        solved = generate_solved_board()
        puzzle, sol = create_puzzle(solved, num_clues=30)
        self.assertEqual(sol, solved)
        zeros = sum(row.count(0) for row in puzzle)
        self.assertEqual(zeros, 81 - 30)

    def test_model_forward(self):
        model = HighLevelCNN()
        input = torch.zeros(1,9,9,9)
        output = model(input)
        self.assertEqual(output.shape, (1,9,9,9))
# --- Main Workflow ---
if __name__ == "__main__":
    # Run unit tests
    print("Running automated unit tests...")
    suite = unittest.TestLoader().loadTestsFromTestCase(TestSudokuModel)
    output = io.StringIO()
    runner = unittest.TextTestRunner(stream=output, verbosity=2)
    result = runner.run(suite)
    print(output.getvalue())
    if not result.wasSuccessful():
        print("Unit tests failed. Aborting.")
        sys.exit(1)
    print("Unit tests passed. Proceeding to data generation and training.")
    # Generate data
    print("Generating training dataset (1000 samples)...")
    train_dataset = SudokuDataset(1000)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    print("Generating test dataset (100 samples)...")
    test_dataset = SudokuDataset(100)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
    # Verify diversity
    _, sample_solution1, _ = train_dataset[0]
    _, sample_solution2, _ = train_dataset[1]
    print("Sample solution 1 first row:", sample_solution1[0].tolist())
    print("Sample solution 2 first row:", sample_solution2[0].tolist())
    # Initialize model
    model = HighLevelCNN()
    # Train with N x T
    print("Starting training...")
    train(model, train_loader, epochs=100, N=3, T=5)
    # Evaluate with meta-adapt
    print("Starting evaluation...")
    evaluate_meta_adapt(model, test_loader)
Running this (one epoch of training, then the evaluation) produces this error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipython-input-2327673764.py in <cell line: 0>()
309 # Evaluate with meta-adapt
310 print("Starting evaluation...")
--> 311 evaluate_meta_adapt(model, test_loader)
3 frames
/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs)
821 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
822 try:
--> 823 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
824 t_outputs, *args, **kwargs
825 ) # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [20736, 729]], which is output 0 of AsStridedBackward0, is at version 34; expected version 33 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Can anyone help me fix the error, assuming the evaluation is implemented correctly? To my thinking it is, since I had to strongly insist that Grok produce the correct gradient flow through the N x T steps. The logic should be the same as the training loop, and the gradients must flow and update correctly.
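In case it helps, here is the smallest standalone snippet I could think of that I believe triggers the same class of error. This is only my attempt at a reproduction, not code from the script above, and the tiny nn.Linear is a made-up stand-in: an optimizer step modifies a weight in place while a retained graph that saved that weight still has to be backpropagated through.

import torch
import torch.nn as nn
import torch.optim as optim

# Tiny stand-in model; the point is only the backward / step ordering.
lin = nn.Linear(4, 4)
opt = optim.Adam(lin.parameters(), lr=0.01)

x = torch.randn(2, 4, requires_grad=True)

# "Cycle 1": forward, backward with retain_graph, then an optimizer step.
out1 = lin(x)
loss1 = out1.pow(2).mean()
loss1.backward(retain_graph=True)
opt.step()  # updates lin.weight in place (its version counter is bumped)

# "Cycle 2": the new loss still depends on out1 (like current_state carrying the
# previous cycle's adapted_logits), so backward has to revisit the old graph,
# which saved the weight that has since been modified in place.
out2 = lin(out1)
loss2 = out2.pow(2).mean()
loss2.backward()  # RuntimeError: ... modified by an inplace operation

If that really is the same failure mode, then the meta_opt.step() in one cycle would be invalidating the graph that the next cycle's backward still reaches through current_state, but I would appreciate someone confirming that.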
This is my first big discovery! Please help!
Thank you