Error when backpropagating through a compiled PyTorch module multiple times

I’m working on a PyTorch project where I have a custom module called AxialRoPE that applies axial rotary position embeddings to tensors. I use both the regular module and a version compiled with torch.compile. When I try to backpropagate through the compiled version multiple times, I encounter the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

However, the regular (uncompiled) version of the module works fine with multiple backpropagation calls.

Here’s the relevant code:

```python
import math
from functools import reduce

import torch
from einops import rearrange
from torch import nn


def bounding_box(
    h: int, w: int, pixel_ar: float = 1.0
) -> tuple[float, float, float, float]:
    # Compute the adjusted aspect ratio based on the pixel aspect ratio
    ar = w / (h * pixel_ar)

    # Compute the bounding box
    h_bounds = (-1.0 / ar, 1.0 / ar) if ar > 1.0 else (-1.0, 1.0)
    w_bounds = (-ar, ar) if ar < 1.0 else (-1.0, 1.0)

    return h_bounds + w_bounds


def centered_linspace(
    start: float,
    end: float,
    steps: int,
    *,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
) -> torch.Tensor:
    edges = torch.linspace(start, end, steps + 1, dtype=dtype, device=device)
    # Compute the midpoint between each pair of edges
    return (edges[:-1] + edges[1:]) / 2


def make_axial_positions(
    h: int,
    w: int,
    pixel_ar: float = 1.0,
    align_corners: bool = False,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
) -> torch.Tensor:
    h_min, h_max, w_min, w_max = bounding_box(h, w, pixel_ar)

    # If align_corners is True, the grid includes the corners of the bounding box
    # Otherwise, the grid points are the centers of the pixels
    linspace_fn = torch.linspace if align_corners else centered_linspace
    h_grid = linspace_fn(h_min, h_max, h, dtype=dtype, device=device)
    w_grid = linspace_fn(w_min, w_max, w, dtype=dtype, device=device)

    # Create a grid of positions
    h_positions, w_positions = torch.meshgrid(h_grid, w_grid, indexing="ij")

    return torch.stack((h_positions, w_positions), dim=-1)


def apply_axial_rope(
    x: torch.Tensor, theta: torch.Tensor, conjugate: bool = False
) -> torch.Tensor:
    # Ensure the operations are performed in at least float32
    dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))

    # Ensure that the dimensions of x and theta are compatible
    dim = theta.shape[-1]
    assert dim * 2 <= x.shape[-1], f"x must have at least {2 * dim} channels"

    # Split x into the two rotated parts and the untouched remainder,
    # and cast the rotated parts and theta to the compute dtype
    x_1, x_2, x_3 = x[..., :dim], x[..., dim : dim * 2], x[..., dim * 2 :]
    x_1, x_2, theta = map(lambda t: t.to(dtype), (x_1, x_2, theta))

    # Compute the rotation angles
    cos, sin = theta.cos(), theta.sin()
    sin = -sin if conjugate else sin

    # Rotate the tensors; both results must use the *unrotated* x_1 and x_2
    y_1 = (cos * x_1 - sin * x_2).to(x.dtype)
    y_2 = (sin * x_1 + cos * x_2).to(x.dtype)

    return torch.cat((y_1, y_2, x_3), dim=-1)


class AxialRoPE(nn.Module):
    def __init__(self, dim: int, n_heads: int) -> None:
        super().__init__()

        freqs_min, freqs_max = math.log(math.pi), math.log(10.0 * math.pi)
        freqs = torch.linspace(freqs_min, freqs_max, n_heads * dim // 4 + 1)[:-1].exp()
        freqs = freqs.view(n_heads, dim // 4).contiguous()

        self.register_buffer("freqs", freqs)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Compute the rotation angles
        freqs = self.freqs.to(positions.dtype)
        h_theta = positions[..., None, 0:1] * freqs
        w_theta = positions[..., None, 1:2] * freqs
        theta = torch.cat((h_theta, w_theta), dim=-1)
        theta = rearrange(theta, "... x y h d -> ... h x y d")

        # Apply the RoPE to the queries and keys
        q = apply_axial_rope(q, theta)
        k = apply_axial_rope(k, theta)

        return q, k


if __name__ == "__main__":
    with torch.device("cuda:0" if torch.cuda.is_available() else "cpu"):
        positions = make_axial_positions(32, 32)
        q = torch.randn(1, 8, 32, 32, 64, requires_grad=True)
        k = torch.randn(1, 8, 32, 32, 64, requires_grad=True)

        axial_rope = AxialRoPE(64, 8)
        q, k = axial_rope(q, k, positions)

        # Test the backward pass
        q.sum().backward()
        k.sum().backward()

        positions = make_axial_positions(32, 32)
        q = torch.randn(1, 8, 32, 32, 64, requires_grad=True)
        k = torch.randn(1, 8, 32, 32, 64, requires_grad=True)

        axial_rope_compiled = torch.compile(AxialRoPE(64, 8))
        q, k = axial_rope_compiled(q, k, positions)

        # Test the backward pass for the compiled version
        q.sum().backward()
        k.sum().backward()
```

Here are my environment details:

PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 23.10 (x86_64)
GCC version: (Ubuntu 13.2.0-4ubuntu3) 13.2.0
Clang version: Could not collect
CMake version: version 3.20.3
Libc version: glibc-2.38

Python version: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-27-generic-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 550.54.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      39 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             32
On-line CPU(s) list:                0-31
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Core(TM) i9-14900KF
CPU family:                         6
Model:                              183
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          1
Stepping:                           1
CPU(s) scaling MHz:                 20%
CPU max MHz:                        6000.0000
CPU min MHz:                        800.0000
BogoMIPS:                           6374.40
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          896 KiB (24 instances)
L1i cache:                          1.3 MiB (24 instances)
L2 cache:                           32 MiB (12 instances)
L3 cache:                           36 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] clip-anytorch==2.6.0
[pip3] dctorch==0.1.2
[pip3] mypy-extensions==1.0.0
[pip3] natten==0.15.1+torch220cu121
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==1.9.5
[pip3] rotary-embedding-torch==0.5.3
[pip3] torch==2.2.1
[pip3] torchaudio==2.2.1
[pip3] torchdiffeq==0.2.3
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.2
[pip3] torchsde==0.2.6
[pip3] torchvision==0.17.1
[pip3] triton==2.2.0
[conda] clip-anytorch             2.6.0                    pypi_0    pypi
[conda] dctorch                   0.1.2                    pypi_0    pypi
[conda] natten                    0.15.1+torch220cu121          pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-lightning         1.9.5                    pypi_0    pypi
[conda] rotary-embedding-torch    0.5.3                    pypi_0    pypi
[conda] torch                     2.2.1                    pypi_0    pypi
[conda] torchaudio                2.2.1                    pypi_0    pypi
[conda] torchdiffeq               0.2.3                    pypi_0    pypi
[conda] torchinfo                 1.8.0                    pypi_0    pypi
[conda] torchmetrics              1.3.2                    pypi_0    pypi
[conda] torchsde                  0.2.6                    pypi_0    pypi
[conda] torchvision               0.17.1                   pypi_0    pypi
[conda] triton                    2.2.0                    pypi_0    pypi

Hi!

This is actually a semi-fundamental issue with torch.compile. In particular:

(1) In eager mode, this works because the backward graphs for q and k are completely independent, so you can call .backward() on each of them separately.

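(2) With torch.compile: because your entire AxialRoPE module is compiled into a single compiled region, torch.compile is free to fuse the backward compute of q and k together. This causes the compiled backward graph to show up as one giant node in the autograd graph, instead of a bunch of tiny autograd nodes that can be run independently. Both q and k are outputs of that one node, so the first .backward() call runs it and frees its saved tensors, and the second call reaches the same node again and hits the error above.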

So the TL;DR here is that the benefits of compilation/fusion come at the cost of some of the fine-grained detail of the autograd graph.
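To see why a single fused node breaks the second call, here is a minimal sketch that reproduces the same error in plain eager mode, without torch.compile. It assumes a compiled region behaves like a custom torch.autograd.Function with several outputs; as far as I know, AOTAutograd does wrap compiled regions this way (the outputs' grad_fn typically shows up as CompiledFunctionBackward), but FusedRotate and separate_backwards are hypothetical names used only for illustration:

```python
import torch


class FusedRotate(torch.autograd.Function):
    """Hypothetical stand-in for a compiled region: one autograd node, two outputs."""

    @staticmethod
    def forward(ctx, q, k, scale):
        # Save a tensor for backward, like a compiled graph saving its intermediates
        ctx.save_for_backward(scale)
        return q * scale, k * scale

    @staticmethod
    def backward(ctx, grad_q, grad_k):
        # Reading saved tensors after they were freed raises the RuntimeError
        (scale,) = ctx.saved_tensors
        # No gradient is needed for `scale`
        return grad_q * scale, grad_k * scale, None


def separate_backwards(fused: bool) -> str:
    q = torch.randn(4, requires_grad=True)
    k = torch.randn(4, requires_grad=True)
    scale = torch.full((4,), 2.0)
    if fused:
        q_out, k_out = FusedRotate.apply(q, k, scale)  # one shared node
    else:
        q_out, k_out = q * scale, k * scale  # two independent nodes (eager)
    try:
        q_out.sum().backward()
        k_out.sum().backward()
        return "ok"
    except RuntimeError as err:
        # Keep only the first sentence of the error message
        return "error: " + str(err).split(".")[0]


print(separate_backwards(fused=False))  # ok
print(separate_backwards(fused=True))
```

With fused=False each output has its own MulBackward0 node, so both calls succeed. With fused=True the first call runs and frees the shared node, and the second call fails with the same "Trying to backward through the graph a second time" error as in the question.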

You effectively have three options here:

Option 1: Run with

```python
(q + k).sum().backward()
```

to perform all backward compute in a single call. This will allow us to compute gradients for q and k simultaneously (especially good for perf if torch.compile has fused their backward compute together).
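Option 1 can be checked with the same kind of hypothetical two-output FusedRotate function (a stand-in for a compiled region, not the code PyTorch actually generates). Because sum is linear, (q + k).sum() yields the same gradients as backpropagating q.sum() and k.sum() separately. When the two outputs do not share a shape, torch.autograd.backward accepts a list of scalar losses and processes them in a single backward pass:

```python
import torch


class FusedRotate(torch.autograd.Function):
    """Hypothetical stand-in for a compiled region with two outputs."""

    @staticmethod
    def forward(ctx, q, k, scale):
        ctx.save_for_backward(scale)
        return q * scale, k * scale

    @staticmethod
    def backward(ctx, grad_q, grad_k):
        (scale,) = ctx.saved_tensors
        return grad_q * scale, grad_k * scale, None


def grads_single_backward(same_shape: bool):
    q = torch.randn(4, requires_grad=True)
    k = torch.randn(4, requires_grad=True)
    q_out, k_out = FusedRotate.apply(q, k, torch.full((4,), 3.0))
    if same_shape:
        # Option 1 as written: one scalar loss, one backward call
        (q_out + k_out).sum().backward()
    else:
        # Equivalent form that does not require q_out and k_out to share a shape
        torch.autograd.backward([q_out.sum(), k_out.sum()])
    return q.grad, k.grad
```

Either way the shared node runs exactly once and receives the gradients for both outputs together, so both q.grad and k.grad end up filled with 3.0 here.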

Option 2: Run with:

```python
q.sum().backward(retain_graph=True)
k.sum().backward()
```

This will allow autograd to keep the compiled backward graph around after the first backward call (at the cost of extra memory). Note, however, that the compiled backward code is allowed to fuse the backward compute for q and k together, so you might end up doing some redundant computation if you call .backward() twice.
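The redundant work from Option 2 can be made visible with the hypothetical two-output FusedRotate stand-in plus a counter (BACKWARD_CALLS is an illustrative name, not a PyTorch feature). With retain_graph=True the shared node survives the first call, but it runs once per .backward(), and each run produces gradients for both inputs; the one for the output that was not differentiated is simply zeros, because autograd materializes missing output gradients as zeros by default:

```python
import torch

BACKWARD_CALLS = [0]  # hypothetical counter, only for illustration


class FusedRotate(torch.autograd.Function):
    """Hypothetical stand-in for a compiled region with two outputs."""

    @staticmethod
    def forward(ctx, q, k, scale):
        ctx.save_for_backward(scale)
        return q * scale, k * scale

    @staticmethod
    def backward(ctx, grad_q, grad_k):
        BACKWARD_CALLS[0] += 1
        (scale,) = ctx.saved_tensors
        # Both gradients are computed on every call, even if one is all zeros
        return grad_q * scale, grad_k * scale, None


def two_backwards_with_retain_graph():
    BACKWARD_CALLS[0] = 0
    q = torch.randn(4, requires_grad=True)
    k = torch.randn(4, requires_grad=True)
    q_out, k_out = FusedRotate.apply(q, k, torch.full((4,), 2.0))
    # Keep the saved tensors alive so the shared node can run a second time
    q_out.sum().backward(retain_graph=True)
    k_out.sum().backward()
    return q.grad, k.grad, BACKWARD_CALLS[0]
```

The final gradients are correct (2.0 everywhere for both q and k), but the fused node executed twice, which is the extra cost mentioned above.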

Option 3:

Explicitly separate the logic for q’s compute and k’s compute into separate regions of code, so you can compile them separately. This effectively forces torch.compile not to bundle any of the compute for q’s backward and k’s backward together, so you can call .backward() separately for each output.
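A minimal sketch of Option 3, under a few assumptions: rotate_half is a hypothetical, simplified version of apply_axial_rope (no dtype handling or conjugate flag), theta is random instead of coming from AxialRoPE, and backend="aot_eager" is used only so the sketch runs without Inductor's compiler toolchain. It still goes through AOTAutograd, so the autograd graph should have the same shape as with the default backend. The idea is to compile only the per-tensor rotation and call it once for q and once for k, so each call produces its own autograd node:

```python
import torch


def rotate_half(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs (x[..., :d], x[..., d:2d]) by theta, d = theta.shape[-1]."""
    d = theta.shape[-1]
    x_1, x_2, x_3 = x[..., :d], x[..., d : 2 * d], x[..., 2 * d :]
    cos, sin = theta.cos(), theta.sin()
    return torch.cat((cos * x_1 - sin * x_2, sin * x_1 + cos * x_2, x_3), dim=-1)


# Compile the per-tensor function once; every call creates a separate autograd node
rotate_compiled = torch.compile(rotate_half, backend="aot_eager")


def separate_regions():
    theta = torch.rand(2, 3, 4)
    q = torch.randn(2, 3, 10, requires_grad=True)
    k = torch.randn(2, 3, 10, requires_grad=True)
    # The caller itself is NOT compiled, so the two calls stay separate regions
    q_out = rotate_compiled(q, theta)
    k_out = rotate_compiled(k, theta)
    q_out.sum().backward()
    k_out.sum().backward()  # no error: q_out and k_out come from different nodes
    return theta, q.grad, k.grad
```

The important design choice is that the code calling rotate_compiled twice is left uncompiled. If it were compiled as well (as with torch.compile(AxialRoPE(...))), Dynamo would likely trace through both calls into a single region again. The trade-off is that torch.compile can no longer fuse work across q and k.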