Higher library and create/retain graph error

Hi, this script throws an error when it reaches evaluate_meta_adapt. I have included the entire script so you can see how the evaluation function is meant to match the training logic, specifically the training gradient flow. I am using Grok to help me code this very complicated model. Grok has hit a wall and keeps producing this error; this version is the closest to correct that I can get it to generate. The code is below, and I have copied and pasted Grok's explanation of what needs to be addressed. To see the error for yourself, run the script for one epoch and then run the evaluation, which will fail.

No, the evaluate_meta_adapt function is not flowing gradients in exactly the same way as training. Here’s a quick analysis:

  • Training Flow: For each batch, it runs N cycles with T inner logit GD steps each (with create_graph=True for higher-order grads), updates state per cycle, then computes a single outer_loss on the final adapted_logits, backprops once, and updates model params once per batch. Gradients flow through the entire N x T chain in one backprop.

  • Eval Flow: It runs N cycles with T inner steps, but computes cycle_loss per cycle, backprops per cycle (multiple backward calls), and updates params per cycle. This is different: multiple backprops per batch vs. one in training. The error (“RuntimeError: Trying to backward through the graph a second time”) occurs because backward() frees the graph after the first call without retain_graph=True. Also, since it’s on the same model (no clone), it modifies trained weights, which is risky for eval.

  • To Match Training: Eval should run the N x T loop, compute a single loss at the end, backprop once. But to adapt params during eval (as requested), add retain_graph=True for multi-cycle backprop, and use a clone to avoid changing the original model (a sketch of this structure follows below).
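To make that last bullet concrete, here is my understanding of the structure it describes, as a minimal sketch (this is not the script I am actually running, which is pasted in full further down): clone the model, run the same N x T loop as training, compute a single self-supervised loss at the end, and backprop once per batch. The function name adapt_on_batch and the return values are placeholders of mine, and it assumes the adaptation_loss function defined in the full script.

import copy
import torch
import torch.nn.functional as F
import torch.optim as optim

def adapt_on_batch(model, puzzles, givens, N=3, T=5, meta_lr=0.001, inner_lr=0.1):
    eval_model = copy.deepcopy(model)   # clone so the trained weights stay untouched
    eval_model.train()                  # keep the clone in train mode for adaptation
    meta_opt = optim.Adam(eval_model.parameters(), lr=meta_lr)

    current_state = puzzles.clone()
    for n in range(N):                  # outer high-level cycles, as in train()
        delta_logits = eval_model(current_state)
        adapted_logits = delta_logits
        for t in range(T):              # inner logit GD steps; create_graph keeps higher-order grads
            inner_l = adaptation_loss(adapted_logits, givens)
            grad = torch.autograd.grad(inner_l, adapted_logits, create_graph=True)[0]
            adapted_logits = adapted_logits - inner_lr * grad
        probs = F.softmax(adapted_logits, dim=-1).permute(0, 3, 1, 2)  # digit dim -> channel dim
        mask = current_state.sum(dim=1, keepdim=True) == 0             # only fill still-empty cells
        current_state = current_state + mask.float() * probs

    # One self-supervised loss and one backward per batch, mirroring training;
    # no parameter update happens mid-loop, so nothing in the graph is modified in place.
    meta_loss = adaptation_loss(adapted_logits, givens)
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()
    return eval_model, adapted_logits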

Updated code with retain_graph=True, clone, and train mode on clone for adaptation.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import copy
import unittest
import sys
import io  # For capturing test output
from torch.utils.data import Dataset, DataLoader
import higher  # For meta-gradient adaptation on params
import wandb  # For monitoring
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import multiprocessing as mp
from itertools import product
import json  # For saving results
import os  # For file paths

# --- Device Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Sudoku Generator Functions ---
def generate_solved_board():
    """Generate a solved 9x9 Sudoku board using randomized backtracking."""
    board = [[0] * 9 for _ in range(9)]
    
    def is_valid(board, row, col, num):
        for i in range(9):
            if board[row][i] == num or board[i][col] == num:
                return False
        box_row, box_col = 3 * (row // 3), 3 * (col // 3)
        for i in range(3):
            for j in range(3):
                if board[box_row + i][box_col + j] == num:
                    return False
        return True
    
    def solve(board):
        for row in range(9):
            for col in range(9):
                if board[row][col] == 0:
                    nums = list(range(1, 10))
                    random.shuffle(nums)
                    for num in nums:
                        if is_valid(board, row, col, num):
                            board[row][col] = num
                            if solve(board):
                                return True
                            board[row][col] = 0
                    return False
        return True
    
    solve(board)
    return board

def create_puzzle(board, num_clues=30):
    puzzle = copy.deepcopy(board)
    cells = [(i, j) for i in range(9) for j in range(9)]
    random.shuffle(cells)
    for i in range(81 - num_clues):
        row, col = cells[i]
        puzzle[row][col] = 0
    return puzzle, board

class SudokuDataset(Dataset):
    def __init__(self, num_samples):
        self.data = []
        for _ in range(num_samples):
            solved = generate_solved_board()
            puzzle, solution = create_puzzle(solved)
            # One-hot input: 9 channels for digits 1-9
            puzzle_onehot = F.one_hot(torch.tensor(puzzle).long(), num_classes=10)[:, :, 1:].permute(2, 0, 1).float()  # (9,9,9)
            solution_tensor = torch.tensor(solution, dtype=torch.long)  # (9,9)
            given_tensor = torch.tensor(puzzle, dtype=torch.long)  # For fixed cells
            self.data.append((puzzle_onehot, solution_tensor, given_tensor))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    puzzles = torch.stack([item[0] for item in batch])  # (B,9,9,9)
    solutions = torch.stack([item[1] for item in batch])  # (B,9,9)
    givens = torch.stack([item[2] for item in batch])  # (B,9,9)
    return puzzles, solutions, givens

# --- High-Level Model (Initializer for Logits) ---
class HighLevelCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(9, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(256 * 9 * 9, 9 * 9 * 9)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = self.dropout(out.view(out.size(0), -1))
        logits = self.fc(out).view(out.size(0), 9, 9, 9)
        return logits

# --- Adaptation Loss (Inner: Constraints + Fixed CE) ---
def adaptation_loss(logits, given):
    p = F.softmax(logits, dim=-1)  # (B,9,9,9)
    
    # Constraints (scaled by 5)
    row_loss = ((p.sum(dim=2) - 1) ** 2).mean()
    col_loss = ((p.sum(dim=1) - 1) ** 2).mean()
    box_p = p.view(-1, 3, 3, 3, 3, 9).sum(dim=(2,4)) - 1
    box_loss = (box_p ** 2).mean()
    const_loss = 5 * (row_loss + col_loss + box_loss)  # Increased weight
    
    # Fixed cells CE
    fixed_loss = 0.0
    B = logits.size(0)
    for b in range(B):
        mask = given[b] > 0
        if mask.sum() > 0:
            fixed_logits = logits[b][mask].view(-1, 9)
            fixed_targets = (given[b][mask] - 1).long().view(-1)
            fixed_loss += F.cross_entropy(fixed_logits, fixed_targets)
    fixed_loss /= B if B > 0 else 1
    
    return const_loss + fixed_loss

# --- Outer Loss (Full CE + Constraints) ---
def outer_loss(logits, solutions):
    flat_logits = logits.view(-1, 9)
    flat_targets = (solutions.view(-1) - 1).long()
    ce_loss = F.cross_entropy(flat_logits, flat_targets)
    
    p = F.softmax(logits, dim=-1)
    row_loss = ((p.sum(dim=2) - 1) ** 2).mean()
    col_loss = ((p.sum(dim=1) - 1) ** 2).mean()
    box_p = p.view(-1, 3, 3, 3, 3, 9).sum(dim=(2,4)) - 1
    box_loss = (box_p ** 2).mean()
    const_loss = 5 * (row_loss + col_loss + box_loss)  # Increased weight
    
    return ce_loss + const_loss

# --- Training Function with N x T Structure ---
def train(model, dataloader, epochs=100, outer_lr=0.001, inner_lr=0.1, N=3, T=5):
    model.to(device)
    outer_opt = optim.Adam(model.parameters(), lr=outer_lr)
    
    for epoch in range(epochs):
        total_outer_loss = 0.0
        for idx, (puzzles, solutions, givens) in enumerate(dataloader):
            puzzles, solutions, givens = puzzles.to(device), solutions.to(device), givens.to(device)
            print(f"Epoch {epoch+1}/{epochs} - Processing batch {idx+1}/{len(dataloader)}")
            
            current_state = puzzles.clone()  # Start with initial puzzle state
            
            for n in range(N):  # Outer high-level cycles
                # High-level: Generate delta logits based on current state
                delta_logits = model(current_state)
                
                adapted_logits = delta_logits.clone().requires_grad_(True)
                
                for t in range(T):  # Inner low-level adaptations
                    inner_l = adaptation_loss(adapted_logits, givens)
                    grad = torch.autograd.grad(inner_l, adapted_logits, create_graph=True)[0]
                    adapted_logits = adapted_logits - inner_lr * grad
                    print(f"    Cycle {n+1}/{N}, Inner step {t+1}/{T}: Adapt loss = {inner_l.item():.6f}")
                
                # Update state for next cycle: Soft-merge adapted probs into state
                probs = F.softmax(adapted_logits, dim=-1).permute(0, 3, 1, 2)  # (B,9,9,digits) -> (B,digits,9,9) to match current_state layout
                mask = current_state.sum(dim=1, keepdim=True) == 0  # Update only remaining empties
                current_state = current_state + mask.float() * probs  # Soft fill
                
            # Final outer loss after all cycles
            final_logits = adapted_logits  # From last inner
            o_loss = outer_loss(final_logits, solutions)
            total_outer_loss += o_loss.item() * puzzles.size(0)
            
            outer_opt.zero_grad()
            o_loss.backward()
            outer_opt.step()
            
            print(f"  Outer loss: {o_loss.item():.6f}")
        
        avg_outer_loss = total_outer_loss / len(dataloader.dataset)
        wandb.log({"avg_outer_loss": avg_outer_loss})
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1} completed. Avg outer loss: {avg_outer_loss:.6f}")

# --- New Evaluation with Backprop Through Meta-Gradients (Test-Time Adaptation on Params) ---
def evaluate_meta_adapt(model, dataloader, N=3, T=5, meta_lr=0.001, inner_lr=0.1):
    correct_puzzles = 0
    correct_cells = 0
    total_puzzles = 0
    total_cells = 0

    for idx, (puzzles, solutions, givens) in enumerate(dataloader):
        puzzles, solutions, givens = puzzles.to(device), solutions.to(device), givens.to(device)
        print(f"Evaluating batch {idx+1}/{len(dataloader)}")

        meta_opt = optim.Adam(model.parameters(), lr=meta_lr)

        current_state = puzzles.clone()

        for n in range(N):
            # High-level: Generate delta logits based on current state
            delta_logits = model(current_state)
            
            adapted_logits = delta_logits.clone().requires_grad_(True)
            
            for t in range(T):
                inner_l = adaptation_loss(adapted_logits, givens)
                grad = torch.autograd.grad(inner_l, adapted_logits, create_graph=True, retain_graph=True)[0]
                adapted_logits = adapted_logits - inner_lr * grad
                print(f"    Cycle {n+1}/{N}, Inner step {t+1}/{T}: Adapt loss = {inner_l.item():.6f}")
            
            # Backprop through the cycle for meta-update
            cycle_loss = adaptation_loss(adapted_logits, givens)  # Use adaptation_loss as meta loss (self-supervised)
            meta_opt.zero_grad()
            cycle_loss.backward(retain_graph=True)
            meta_opt.step()
            
            # Update state for next cycle
            probs = F.softmax(adapted_logits, dim=-1).permute(0, 3, 1, 2)  # (B,9,9,digits) -> (B,digits,9,9) to match current_state layout
            mask = current_state.sum(dim=1, keepdim=True) == 0  # Update only remaining empties
            current_state = current_state + mask.float() * probs  # Soft fill

        final_logits = model(current_state)  # Final pass after all adaptations
        pred = torch.argmax(final_logits, dim=-1) + 1  # (B,9,9)

        is_correct_puzzle = torch.all(pred == solutions, dim=(1,2))
        correct_puzzles += is_correct_puzzle.sum().item()

        correct_cells += (pred == solutions).sum().item()
        total_cells += pred.numel()

        total_puzzles += puzzles.size(0)

        print(f"  Batch solve rate: {is_correct_puzzle.sum().item() / puzzles.size(0):.2%} | Sample pred row 0 (first sample): {pred[0][0].tolist()}")

    puzzle_accuracy = correct_puzzles / total_puzzles
    cell_accuracy = correct_cells / total_cells
    print(f"Evaluation complete. Solve accuracy: {puzzle_accuracy:.2%} | Per-cell accuracy: {cell_accuracy:.2%}")
    model.train()

# --- Unit Tests ---
class TestSudokuModel(unittest.TestCase):
    def test_generator(self):
        board = generate_solved_board()
        self.assertEqual(len(board), 9)
        self.assertEqual(len(board[0]), 9)
        self.assertFalse(any(0 in row for row in board))
        for row in board:
            self.assertEqual(sum(row), sum(range(1,10)))

    def test_puzzle_creation(self):
        solved = generate_solved_board()
        puzzle, sol = create_puzzle(solved, num_clues=30)
        self.assertEqual(sol, solved)
        zeros = sum(row.count(0) for row in puzzle)
        self.assertEqual(zeros, 81 - 30)

    def test_model_forward(self):
        model = HighLevelCNN()
        input = torch.zeros(1,9,9,9)
        output = model(input)
        self.assertEqual(output.shape, (1,9,9,9))

# --- Main Workflow ---
if __name__ == "__main__":
    # Run unit tests
    print("Running automated unit tests...")
    suite = unittest.TestLoader().loadTestsFromTestCase(TestSudokuModel)
    output = io.StringIO()
    runner = unittest.TextTestRunner(stream=output, verbosity=2)
    result = runner.run(suite)
    print(output.getvalue())
    if not result.wasSuccessful():
        print("Unit tests failed. Aborting.")
        sys.exit(1)
    print("Unit tests passed. Proceeding to data generation and training.")

    # Generate data
    print("Generating training dataset (1000 samples)...")
    train_dataset = SudokuDataset(1000)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    print("Generating test dataset (100 samples)...")
    test_dataset = SudokuDataset(100)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    # Verify diversity
    _, sample_solution1, _ = train_dataset[0]
    _, sample_solution2, _ = train_dataset[1]
    print("Sample solution 1 first row:", sample_solution1[0].tolist())
    print("Sample solution 2 first row:", sample_solution2[0].tolist())

    # Initialize model
    model = HighLevelCNN()

    # Train with N x T
    print("Starting training...")
    train(model, train_loader, epochs=100, N=3, T=5)

    # Evaluate with meta-adapt
    print("Starting evaluation...")
    evaluate_meta_adapt(model, test_loader)

Running the evaluation produces this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-2327673764.py in <cell line: 0>()
    309     # Evaluate with meta-adapt
    310     print("Starting evaluation...")
--> 311     evaluate_meta_adapt(model, test_loader)

3 frames

/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs)
    821         unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    822     try:
--> 823         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    824             t_outputs, *args, **kwargs
    825         )  # Calls into the C++ engine to run the backward pass

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [20736, 729]], which is output 0 of AsStridedBackward0, is at version 34; expected version 33 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
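Following the hint at the end of the error, my understanding is that anomaly detection can be turned on like this to locate the failing in-place op (debug only, since it slows everything down noticeably):

import torch

# Debug only: autograd records forward-pass stack traces so the backward error
# points at the operation whose input was later modified in place.
torch.autograd.set_detect_anomaly(True)

evaluate_meta_adapt(model, test_loader)  # re-run the failing call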

Can anyone help me fix the error, assuming the evaluation is implemented correctly? To my thinking it is, since I had to strongly insist that Grok produce the correct gradient flow through the N x T steps. The logic should be the same as the training loop, and the gradients must flow and update correctly.

This is my first big discovery! Please help!

Thank you

Problem resolved, the model works :wink: