Using AMP with CUDA Graphs and DDP - Scaler Error

I am investigating speeding up training using the CUDA graphs PyTorch API. My training loop works correctly with AMP/autocast disabled, but with it enabled I get an error on the scaler.step(optimizer) call.
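
For context, the standard eager AMP recipe I am trying to combine with graph capture is the usual scale/step/update pattern. A self-contained toy version (placeholder model and data, not my actual setup) runs fine on its own:

import torch

# Toy eager-mode AMP loop (placeholder model/data, just to show the pattern):
model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    images = torch.randn(4, 8, device='cuda')
    labels = torch.randint(0, 2, (4,), device='cuda')
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(images), labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # no issues here; the failure only appears with graph capture
    scaler.update()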

Below is my training loop:

warmup_loader = get_warmup_loader(args, rank)
device = get_device(args)
use_amp = args.get('amp').get('enabled')

if is_main_node(rank):
    print(f'\nWarming up')

# CUDA Graphs Warming Up Iteration

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model = get_model(args)
    model.train()

    for batch, (images, labels) in enumerate(tqdm(warmup_loader, disable=not is_main_node(rank))):
        images = images.to(device, memory_format=get_memory_format(args))
        labels = labels.to(device, memory_format=get_memory_format(args))

        with autocast(enabled=use_amp):
            y_pred = model(images)['out']
            loss = criterion(y_pred, labels)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        optimizer.zero_grad(set_to_none=True)

torch.cuda.current_stream().wait_stream(s)

# CUDA Graphs Capture

# Static placeholders the captured graph reads from; real batches are copied into them before each replay
capture_input = torch.empty(images.shape, device=device)
capture_target = torch.empty(labels.shape, device=device)

g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    with autocast(enabled=use_amp):
        capture_y_pred = model(capture_input)['out']
        capture_loss = criterion(capture_y_pred, capture_target)

    if use_amp:
        scaler.scale(capture_loss).backward()
    else:
        capture_loss.backward()
        optimizer.step()
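    # With AMP, only the forward pass and the scaled backward are captured here;
    # scaler.step() and scaler.update() run eagerly after each replay (see the training loop below).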

if is_main_node(rank):
    print("\nTraining")

# Training on real data

start = time.time()
for epoch in range(args.get('hyper-params').get('epochs')):
    for batch, (images, labels) in enumerate(tqdm(train_loader, disable=not is_main_node(rank))):
        capture_input.copy_(images)
        capture_target.copy_(labels)
        g.replay()
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
end = time.time()

if is_main_node(rank):
    print(f"\nTraining took: {end - start:.2f}s")

I get the following error:

assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

I have been following the documentation found here: CUDA semantics — PyTorch master documentation.
From my understanding, scaler.step(optimizer) should decide whether to apply or skip the optimizer step based on the inf/NaN checks recorded for the gradients?
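
To illustrate that understanding, this is the eager-mode behaviour I expected (a standalone toy example, not my actual training code): a non-finite gradient should make the scaler silently skip the step and back off the scale, rather than raise.

import torch

# Toy check of the expected skip behaviour (standalone, unrelated to the graph code above):
param = torch.zeros(4, device='cuda', requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.1)
scaler = torch.cuda.amp.GradScaler()

loss = (param * float('inf')).sum()   # forces non-finite gradients
scaler.scale(loss).backward()
scaler.step(optimizer)                # expected: skips the step instead of raising
scaler.update()                       # expected: reduces the loss scale
print(param, scaler.get_scale())      # param unchanged, scale backed off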

Any help would be great.

@ptrblck Could you help with this issue?