kwohlfahrt (Kai Wohlfahrt) | November 20, 2024, 7:13pm | #1
The documentation for torch.utils.checkpoint gives many reasons to prefer the use_reentrant=False version. However, it doesn't say when use_reentrant=True is required.
I am trying to migrate a model to use_reentrant=False, but I see errors like the example below. The same model runs successfully with use_reentrant=True, and I'm not sure what causes the incompatibility. I haven't yet been able to reduce this to a minimal example, and as far as I know we aren't doing anything unusual with the backward pass.
tensor at position 160:
saved metadata: {'shape': torch.Size([512, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([128, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
What might cause checkpoint(..., use_reentrant=False) to fail when checkpoint(..., use_reentrant=True) works?
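For reference, the change in question is only the flag passed to checkpoint. A minimal sketch of the before/after call (the module and input here are placeholders, not the real model):

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(128, 128)            # placeholder module
x = torch.randn(4, 128, requires_grad=True)  # placeholder input

out = checkpoint(block, x, use_reentrant=True)   # current call, which works
out = checkpoint(block, x, use_reentrant=False)  # migrated call, which hits the error above in the full model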
kwohlfahrt (Kai Wohlfahrt) | December 3, 2024, 2:29pm | #2
I’ve now found a minimal reproduction for this, and it seems to be an interaction between mixed-precision and checkpointing. I’ve reported the issue upstream: CheckpointError with checkpoint(..., use_reentrant=False) & autocast() · Issue #141896 · pytorch/pytorch · GitHub
For context, this is sufficient to reproduce the issue:
import torch
from torch.utils.checkpoint import checkpoint


class PairStack(torch.nn.Module):
    def __init__(self, size: int):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.LayerNorm(size),
            torch.nn.Linear(size, size),
        )

    def forward(self, x):
        # Issue also occurs if we unconditionally checkpoint()
        if torch.is_grad_enabled():
            return checkpoint(self.block, x, use_reentrant=False, debug=True)
        else:
            return self.block(x)


class Mod(torch.nn.Module):
    def __init__(self, size: int):
        super().__init__()
        self.stack = PairStack(size)
        self.linear = torch.nn.Linear(size, size)

    def forward(self, x):
        with torch.set_grad_enabled(False):
            x = self.stack(x)
        x = self.stack(x)
        return self.linear(x)


def main():
    device = torch.device("cpu")
    size = 64

    m = Mod(size=size).to(device)
    x = torch.linspace(0, 1, 2 * 3 * size).reshape(2, 3, size).to(device)

    with torch.autocast(device.type, dtype=torch.bfloat16):
        output = m(x)
        loss = output.sum()
        loss.backward()


if __name__ == "__main__":
    main()
Does anyone have advice on where to start fixing this?
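One place that might be worth looking is autocast's cast cache, since with use_reentrant=False the forward pass is recomputed during backward(), and the cache state could differ between the original and recomputed passes. As a rough sketch (untested, and only a guess at the cause), torch.autocast accepts a cache_enabled argument, so the autocast block in main() above could be run with the cache disabled:

# Possible experiment (untested): run with autocast's cast cache disabled, so the
# checkpointed recomputation casts tensors the same way as the original forward.
with torch.autocast(device.type, dtype=torch.bfloat16, cache_enabled=False):
    output = m(x)
    loss = output.sum()
    loss.backward()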