Runtime Error only occurs with torch.compile()

I’m getting a RuntimeError that only shows up when I use torch.compile() (in both the nightly and stable releases). In eager mode, the script runs without error.

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
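For reference, in plain eager mode this exact message is what you get when you call backward() twice through the same graph, and retain_graph=True is the escape hatch it suggests. A minimal CPU-only sketch (the function names are illustrative, not part of the repro):

```python
import torch


def second_backward_fails() -> bool:
    w = torch.ones(3, requires_grad=True)
    y = (w * w).sum()  # the multiply saves its operands for the backward pass
    y.backward()       # saved tensors are freed after this call
    try:
        y.backward()   # backward through the same, already-freed graph
    except RuntimeError as err:
        return "second time" in str(err)
    return False


def retained_backward_grad() -> torch.Tensor:
    w = torch.ones(3, requires_grad=True)
    y = (w * w).sum()
    y.backward(retain_graph=True)  # keep the saved tensors alive
    y.backward()                   # second call succeeds; gradients accumulate
    return w.grad


print(second_backward_fails())   # True
print(retained_backward_grad())  # tensor([4., 4., 4.])
```

The repro below never calls backward() twice on the same loss, which is why the error is surprising here.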

Here is a minimal script that reproduces the error. If you comment out the line net = torch.compile(net) then it will run without error:

import torch
from torch import nn
import torch.nn.functional as F


def device():
    return "cuda"


class WidthBuffer:
    def __init__(self):
        self._values = None

    def size(self):
        if self._values is None:
            return 0
        return self._values.shape[3]

    def add(self, new_values):
        if self._values is None:
            self._values = new_values
        else:
            self._values = torch.cat((self._values, new_values), 3)

    def remove(self, size):
        assert size <= self._values.shape[3]
        removed = self._values[:, :, :, :size]
        self._values = self._values[:, :, :, size:]
        return removed

    def detach(self):
        if self._values is not None:
            self._values = self._values.detach()


class FinalConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.buffer = WidthBuffer()
        self.conv = nn.Conv2d(1, 4, 3, 1, 1, device=device())

    def detach(self):
        self.buffer.detach()

    def forward(self, inputs, higher):
        if higher is not None:
            self.buffer.add(higher)
        x = self.buffer.remove(inputs.shape[3])
        x = self.conv(x)
        return x


class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.values = None
        self.conv_1 = nn.Conv2d(4, 1, 3, 2, device=device(), padding=1)
        self.conv_2 = nn.ConvTranspose2d(1, 1, 2, 2, padding=0, device=device())
        self.buffer = WidthBuffer()
        self.input_buffer = WidthBuffer()
        self.input_shape = None

    def compute1(self, inputs):
        self.input_buffer.add(inputs)
        self.input_shape = inputs.shape
        size = self.input_buffer.size()
        remove_count = (size // 2) * 2
        if remove_count:
            inputs = self.input_buffer.remove(remove_count)
            self.values = self.conv_1(inputs)
        else:
            empty_values_shape = (inputs.shape[0], 1, inputs.shape[2] // 2, 0)
            self.values = torch.zeros(empty_values_shape, device=inputs.device)

    def compute2(self):
        out = self.conv_2(self.values)
        self.values = None
        return out

    def detach(self):
        self.buffer.detach()
        self.input_buffer.detach()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([Layer()])
        self.final_conv = FinalConv()
        self.requires_start_init = True

    def forward(self, input_idxs, process_size):
        if process_size and input_idxs.shape[3] != process_size:
            padding = process_size - input_idxs.shape[3]
            input_idxs = F.pad(input_idxs, (0, padding))
        values = original_inputs = F.one_hot(input_idxs.squeeze(1), 4).permute(0, 3, 1, 2).float()
        if self.requires_start_init:
            self.start_init(values.shape, values.device)
        self.layers[0].compute1(values)
        x = self.layers[0].compute2()
        x = self.final_conv(original_inputs, x)
        return x

    def start_init(self, input_shape, device):
        inputs = values = torch.zeros(input_shape[:3] + (1,), device=device)
        for layer in self.layers:
            values = layer.compute1(values)
            layer.values = torch.zeros(layer.values.shape[:3] + (1,), device=device)
        self.final_conv(inputs, self.layers[0].compute2())
        self.requires_start_init = False

    def detach(self):
        self.layers[0].detach()
        self.final_conv.detach()


def train_net(net):
    optimizer = torch.optim.Adam(net.parameters())
    process_size = 10
    for train_idx in range(10):
        x = torch.rand((1, 6, 20), device=device())
        for width_idx in range(0, x.shape[-1], process_size):
            inputs = x[:, :, width_idx: width_idx + process_size].unsqueeze(1).long()
            out = net(inputs, process_size)
            loss = out.sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            net.detach()
            print(f"Loss={loss.item():0.2f}")


net = Model()
net = torch.compile(net)  # Works when this line is commented out
train_net(net)

My environment:

$ python3 -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.1.0.dev20230429+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.35

Python version: 3.10.6 (main, Mar 10 2023, 10:55:28) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-41-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA GeForce GTX TITAN X

Nvidia driver version: 525.105.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          32
On-line CPU(s) list:             0-31
Vendor ID:                       AuthenticAMD
Model name:                      AMD Ryzen Threadripper PRO 5955WX 16-Cores
CPU family:                      25
Model:                           8
Thread(s) per core:              2
Core(s) per socket:              16
Socket(s):                       1
Stepping:                        2
Frequency boost:                 enabled
CPU max MHz:                     7031.2500
CPU min MHz:                     1800.0000
BogoMIPS:                        7985.23
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                  AMD-V
L1d cache:                       512 KiB (16 instances)
L1i cache:                       512 KiB (16 instances)
L2 cache:                        8 MiB (16 instances)
L3 cache:                        64 MiB (2 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-31
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] pytorch-triton==2.1.0+7d1a95b046
[pip3] torch==2.1.0.dev20230429+cu118
[pip3] torchaudio==2.1.0.dev20230429+cu118
[pip3] torchvision==0.16.0.dev20230429+cu118
[pip3] triton==2.0.0
[conda] Could not collect

I believe you might be running into this issue, so feel free to comment on it with your use case.

I’m not sure it’s the same issue. In this code's training loop, backward() is called only once per iteration, not twice, and afterwards net.detach() ensures that no tensor left in any buffer is still attached to the autograd graph.
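The pattern described here (one backward() per iteration, then detach() on any state carried into the next iteration) is the usual way to avoid this error in eager mode. A minimal CPU-only sketch, assuming a hypothetical CachedSum module that plays the role of WidthBuffer: without the detach, the second iteration backpropagates into the previous iteration's already-freed graph.

```python
import torch
from torch import nn


class CachedSum(nn.Module):
    """Hypothetical module that carries a tensor across forward calls."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1))
        self.cache = None

    def forward(self, x):
        out = x * self.weight
        if self.cache is not None:
            out = out + self.cache  # links this graph to the previous one
        self.cache = out            # state that outlives this iteration
        return out.sum()

    def detach(self):
        if self.cache is not None:
            self.cache = self.cache.detach()  # cut the link to the old graph


def train(steps: int, detach_between_steps: bool) -> None:
    net = CachedSum()
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    for _ in range(steps):
        loss = net(torch.ones(4))
        loss.backward()
        opt.step()
        opt.zero_grad()
        if detach_between_steps:
            net.detach()  # same role as net.detach() in train_net
```

Calling train(3, detach_between_steps=True) runs cleanly, while train(2, detach_between_steps=False) raises the "backward through the graph a second time" error on the second iteration.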

Or maybe it is the same issue. Adding self.detach() at the end of start_init() solves the problem.
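A minimal sketch of where that call goes, using a hypothetical WarmStartModel in place of Model. It only illustrates placing self.detach() at the end of the one-time warm-up so that no warm-up tensor stays attached to the autograd graph; it is not expected to reproduce the torch.compile() failure itself.

```python
import torch
from torch import nn


class WarmStartModel(nn.Module):
    """Hypothetical stand-in for Model: a one-time warm-up writes
    graph-attached tensors into internal state."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.state = None
        self.requires_start_init = True

    def start_init(self, shape):
        # Warm-up pass: self.state now depends on self.proj's parameters
        self.state = self.proj(torch.zeros(shape))
        self.requires_start_init = False
        self.detach()  # the workaround: detach warm-up state before training

    def detach(self):
        if self.state is not None:
            self.state = self.state.detach()

    def forward(self, x):
        if self.requires_start_init:
            self.start_init(x.shape)
        out = self.proj(x) + self.state
        self.state = out
        return out.sum()


def train(model: WarmStartModel, steps: int) -> None:
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        loss = model(torch.ones(2, 4))
        loss.backward()
        opt.step()
        opt.zero_grad()
        model.detach()  # same role as net.detach() in train_net
```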

Can you check whether you’re still hitting the issue on a recent nightly? There was a recent fix here: "AOTAutograd: fix 'Trying to backward through the graph a second time' error" by bdhirsh, Pull Request #98960 in pytorch/pytorch. It won’t fix every instance of this problem, but it should fix some common ones.