Checkpoint() doesn't work with use_reentrant=False

The documentation for torch.utils.checkpoint gives many reasons to prefer the use_reentrant=False version. However, it doesn’t say when use_reentrant=True is required.
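For reference, the only difference at the call sites is the flag; block below is just a stand-in for whichever submodule is being checkpointed:

from torch.utils.checkpoint import checkpoint

out = checkpoint(block, x, use_reentrant=True)   # runs fine for us
out = checkpoint(block, x, use_reentrant=False)  # raises a CheckpointError during backward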

I am trying to migrate a model to use_reentrant=False, but I see errors like the example below. The same model runs successfully with use_reentrant=True, and I'm not sure what causes the incompatibility. I haven't yet been able to reduce this to a minimal example, and as far as I know we aren't doing anything unusual with the backward pass.

tensor at position 160:
saved metadata: {'shape': torch.Size([512, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([128, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}

What might cause checkpoint(..., use_reentrant=False) to fail when checkpoint(..., use_reentrant=True) works?

I’ve now found a minimal reproduction for this, and it seems to be an interaction between mixed precision (autocast) and checkpointing. I’ve reported the issue upstream: CheckpointError with checkpoint(..., use_reentrant=False) & autocast() · Issue #141896 · pytorch/pytorch · GitHub

For context, this is sufficient to reproduce the issue:

import torch
from torch.utils.checkpoint import checkpoint


class PairStack(torch.nn.Module):
    def __init__(self, size: int):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.LayerNorm(size),
            torch.nn.Linear(size, size)
        )

    def forward(self, x):
        # Issue also occurs if we unconditionally checkpoint()
        if torch.is_grad_enabled():
            return checkpoint(self.block, x, use_reentrant=False, debug=True)
        else:
            return self.block(x)


class Mod(torch.nn.Module):
    def __init__(self, size: int):
        super().__init__()
        self.stack = PairStack(size)
        self.linear = torch.nn.Linear(size, size)

    def forward(self, x):
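        # First pass with gradients disabled, so PairStack skips checkpoint();
        # the second pass has gradients enabled and goes through checkpoint().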
        with torch.set_grad_enabled(False):
            x = self.stack(x)
        x = self.stack(x)
        return self.linear(x)


def main():
    device = torch.device("cpu")
    size = 64

    m = Mod(size=size).to(device)
    x = torch.linspace(0, 1, 2 * 3 * size).reshape(2, 3, size).to(device)

    with torch.autocast(device.type, dtype=torch.bfloat16):
        output = m(x)

    loss = output.sum()
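    # backward() runs after the autocast context has exited; with
    # use_reentrant=False, the checkpointed block is recomputed here.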
    loss.backward()

if __name__ == "__main__":
    main()

Does anyone have advice on where to start fixing this?
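
One direction I've been considering, assuming (and this is only a guess on my part) that the mismatch comes from autocast's cast cache going stale between the original forward and the recompute, is to disable that cache, i.e. change the autocast line in main() above to:

# Hypothetical workaround sketch, not a confirmed fix: disable autocast's cast
# cache so every forward pass (including the checkpoint recompute) re-casts
# parameters from scratch.
with torch.autocast(device.type, dtype=torch.bfloat16, cache_enabled=False):
    output = m(x)

I haven't verified whether this addresses the root cause or merely hides it, so any other pointers are welcome.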