I am investigating speeding up training with the PyTorch CUDA Graphs API. My training loop works correctly with AMP/autocast disabled, but with it enabled I get an error on the scaler.step(optimizer) call.
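For context, with AMP disabled the loop is essentially the whole-network capture pattern from the docs, and that version trains without problems. Here is a minimal toy sketch of that working pattern (the model, shapes and names below are made up purely for illustration, they are not my real code):

# Toy sketch of the non-AMP graphed step that works for me.
import torch

device = torch.device('cuda')
model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Static placeholders that the captured graph will read from.
static_input = torch.randn(8, 16, device=device)
static_target = torch.randn(8, 4, device=device)

# Warmup iterations in a side stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(static_input), static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# Capture forward + backward + step into one graph.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = criterion(model(static_input), static_target)
    static_loss.backward()
    optimizer.step()

# Training: copy real data into the placeholders and replay.
for _ in range(10):
    static_input.copy_(torch.randn(8, 16, device=device))
    static_target.copy_(torch.randn(8, 4, device=device))
    g.replay()

My real loop follows the same structure, just with my own model, loaders, and the AMP branches added.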
Below is my training loop:
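# Note: criterion, optimizer, scaler, train_loader, etc. are created earlier in the script (not shown).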
warmup_loader = get_warmup_loader(args, rank)
device = get_device(args)
use_amp = args.get('amp').get('enabled')

if is_main_node(rank):
    print(f'\nWarming up')

# CUDA Graphs Warming Up Iteration
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model = get_model(args)
    model.train()
    for batch, (images, labels) in enumerate(tqdm(warmup_loader, disable=not is_main_node(rank))):
        images = images.to(device, memory_format=get_memory_format(args))
        labels = labels.to(device, memory_format=get_memory_format(args))
        with autocast(enabled=use_amp):
            y_pred = model(images)['out']
            loss = criterion(y_pred, labels)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
torch.cuda.current_stream().wait_stream(s)
# CUDA Graphs Capture
capture_input = torch.empty(images.shape, device=device)
capture_target = torch.empty(labels.shape, device=device)
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    with autocast(enabled=use_amp):
        capture_y_pred = model(capture_input)['out']
        capture_loss = criterion(capture_y_pred, capture_target)
    if use_amp:
        scaler.scale(capture_loss).backward()
    else:
        capture_loss.backward()
        optimizer.step()
if is_main_node(rank):
    print("\nTraining")

# Training on real data
start = time.time()
for epoch in range(args.get('hyper-params').get('epochs')):
    for batch, (images, labels) in enumerate(tqdm(train_loader, disable=not is_main_node(rank))):
        capture_input.copy_(images)
        capture_target.copy_(labels)
        g.replay()
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
end = time.time()

if is_main_node(rank):
    print(f"\nTraining took: {end - start:.2f}s")
I get the following error:
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
I have been following the documentation found here: CUDA semantics — PyTorch master documentation.
From my understanding, the scaler should know whether to apply or skip the optimizer step based on the inf/NaN checks on the gradients?
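That is at least what happens in a plain eager AMP loop; here is a toy sketch of the behaviour I mean (again, the model, shapes and names are made up just for illustration):

# Toy eager AMP loop, only to illustrate the skip-on-inf behaviour I am relying on.
import torch
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda')
model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for _ in range(10):
    images = torch.randn(8, 16, device=device)
    labels = torch.randn(8, 4, device=device)
    optimizer.zero_grad(set_to_none=True)
    with autocast():
        loss = criterion(model(images), labels)
    scaler.scale(loss).backward()  # backward on the scaled loss produces scaled grads
    scaler.step(optimizer)         # unscales the grads, checks them for inf/NaN, then steps or skips
    scaler.update()                # adjusts the scale factor based on what was found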
Any help would be great.