Hi,
I implemented this validation loop for evaluation with PyTorch DDP, based on the official example examples/main.py at e4e8da8467d55d28920dbd137261d82255f68c71 · pytorch/examples · GitHub.
I am using drop_last=True in the validation DistributedSampler, and then adding a final step that runs validation once more on the remaining data (when the number of examples is not divisible by my world_size), i.e. the samples that were stripped off by drop_last=True in the DistributedSampler.
Can anyone please tell me whether this implementation is OK?
Also, is this implementation equivalent to using drop_last=False in the validation DistributedSampler?
I would really appreciate your feedback on this.
def validate(valid_loader, model, loss_fn, optimizer, scaler, world_size,
             early_stopping, epoch, results_path, batch_size, num_workers,):
    """Run one distributed validation pass and print the mean per-batch loss.

    Follows the structure of the official PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/main/imagenet/main.py

    Each rank evaluates its own shard of ``valid_loader``; the per-rank loss
    sums and batch counts are then summed across ranks with ``all_reduce``.
    If the sampler covered fewer samples than the dataset holds, the
    leftover tail of the dataset is evaluated afterwards and added exactly
    once.

    Args:
        valid_loader: Validation loader yielding ``(images, labels, wsi_id)``
            triples; its ``sampler`` and ``dataset`` must support ``len()``.
        model: Model whose forward pass returns a 4-tuple with the logits
            first.
        loss_fn: Callable taking ``(logits, labels)`` and returning the loss.
        optimizer, scaler, early_stopping, epoch, results_path: Not used by
            this function body; kept so existing callers keep working.
        world_size: Number of participating processes.
        batch_size: Batch size for the auxiliary (leftover) loader.
        num_workers: Worker count for the auxiliary (leftover) loader.

    Notes:
        * The reported value is the mean over *batches*, not over samples.
        * NOTE(review): the leftover range assumes the sampler dropped the
          *last* dataset indices. That only holds when the sampler was built
          with ``shuffle=False`` -- confirm against the loader setup.
    """

    def run_validate(loader):
        """Return ``(sum of batch losses, number of batches)`` for ``loader``."""
        # Accumulate in a float32 tensor rather than a Python float so the
        # result is always a tensor (required by all_reduce, even when the
        # loader yields no batches) and is not summed in float16.
        loss_sum = torch.zeros((), dtype=torch.float32, device="cuda")
        count = torch.zeros(1, dtype=torch.float32, device="cuda")
        with torch.no_grad():
            for images, labels, _ in loader:
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
                # Drops every size-1 dimension -- presumably the batch
                # dimension of a single-bag batch; TODO confirm.
                images = torch.squeeze(images)
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    logits, _, _, _ = model(images)
                    vloss = loss_fn(logits, labels)
                count += 1
                loss_sum = loss_sum + vloss
        return loss_sum, count

    model.eval()
    valid_loss, count = run_validate(valid_loader)

    # Reduce the sharded results BEFORE touching the leftover samples: every
    # rank evaluates the same leftover subset, so adding it prior to a SUM
    # all_reduce would count those samples world_size times.
    dist.barrier()
    dist.all_reduce(valid_loss, dist.ReduceOp.SUM, async_op=False)
    dist.all_reduce(count, dist.ReduceOp.SUM, async_op=False)

    # If the number of examples is not divisible by world_size and
    # drop_last=True on the DistributedSampler, evaluate the remaining
    # samples (identically on every rank) and add them once.
    num_sharded = len(valid_loader.sampler) * world_size
    num_total = len(valid_loader.dataset)
    if num_sharded < num_total:
        aux_val_dataset = Subset(valid_loader.dataset,
                                 range(num_sharded, num_total))
        aux_val_loader = DataLoader(aux_val_dataset, batch_size=batch_size,
                                    shuffle=False, num_workers=num_workers,
                                    pin_memory=True)
        valid_loss_aux, count_aux = run_validate(aux_val_loader)
        # Update the metrics.
        valid_loss = torch.add(valid_loss, valid_loss_aux)
        count = count + count_aux

    avg_val_loss = float((valid_loss / count).detach().cpu())
    print(f"Total validation loss = {avg_val_loss:.4f}")