Validation with PyTorch DDP


I implemented this validation loop for evaluating a model with PyTorch DDP, based on the official examples in the pytorch/examples GitHub repository (commit e4e8da8467d55d28920dbd137261d82255f68c71).

I set drop_last=True on the DistributedSampler of the validation DataLoader. I then add a final step that runs validation on the remaining samples: when the number of examples is not divisible by world_size, these are the samples that drop_last=True strips off.

Could anyone please tell me whether this implementation is correct?
Also, is this implementation equivalent to using drop_last=False on the validation DistributedSampler?

I would really appreciate your feedback on this.

def validate(valid_loader, model, loss_fn, optimizer, scaler, world_size,
             early_stopping, epoch, results_path, batch_size, num_workers,):
    """Compute the average validation loss over the whole dataset under DDP.

    This follows the pattern of the official PyTorch examples:

    1. Each rank evaluates its own shard from ``valid_loader``.
    2. The per-rank loss sums and batch counts are all-reduced.
    3. Every rank then evaluates, locally, the leftover samples that a
       ``drop_last=True`` DistributedSampler dropped, and adds them to the
       already-reduced totals.

    The leftovers must be added *after* the all-reduce. Otherwise every rank
    contributes the same leftover batches and they are counted
    ``world_size`` times.

    NOTE: per the DistributedSampler docs, ``drop_last=False`` pads the shards
    by repeating samples, so it is not equivalent to this approach.

    Args:
        valid_loader: DataLoader yielding ``(images, labels, wsi_id)``. Its
            sampler is expected to be a DistributedSampler with
            ``drop_last=True``. The leftover computation assumes the sampler
            uses ``shuffle=False`` (TODO confirm); with shuffling, the dropped
            indices are not the dataset's tail.
        model: Model called as ``model(images)``, returning a 4-tuple whose
            first element is the logits. It is switched to eval mode for the
            duration of the call, then restored to its previous mode.
        loss_fn: Called as ``loss_fn(logits, labels)``.
        optimizer, scaler, early_stopping, epoch, results_path: Unused here;
            kept for call-site compatibility.
        world_size: Number of DDP processes.
        batch_size, num_workers: Used to build the loader for the leftover
            samples.

    Returns:
        The average validation loss (float), identical on every rank.

        This is an average of per-batch losses, each batch weighted equally.
        If ``loss_fn`` returns a per-batch mean (TODO confirm its reduction),
        it equals the per-sample mean only when all batches have the same
        size.
    """

    def run_validate(loader, loss_sum, batch_count):
        """Add the summed loss and the number of batches in ``loader`` to the accumulators."""
        with torch.no_grad():
            for images, labels, _wsi_id in loader:
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
                images = torch.squeeze(images)

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    logits, _y_prob, _y_hat, _ = model(images)
                    vloss = loss_fn(logits, labels)

                # Accumulate in float32; the loss may come back in reduced
                # precision under autocast.
                loss_sum += vloss.detach().float()
                batch_count += 1
        return loss_sum, batch_count

    was_training = model.training
    # Eval mode: without it, dropout stays active and BatchNorm keeps
    # updating its running statistics with validation data, even under no_grad.
    model.eval()
    try:
        # Tensors (not Python floats) so all_reduce works even when a rank
        # receives no batches.
        valid_loss = torch.zeros(1, dtype=torch.float32, device="cuda")
        count = torch.zeros(1, dtype=torch.float32, device="cuda")

        valid_loss, count = run_validate(valid_loader, valid_loss, count)

        # Combine the sharded part across ranks BEFORE adding the leftovers.
        dist.all_reduce(valid_loss, dist.ReduceOp.SUM, async_op=False)
        dist.all_reduce(count, dist.ReduceOp.SUM, async_op=False)

        # Samples dropped by the DistributedSampler (drop_last=True) when the
        # dataset size is not divisible by world_size. Every rank evaluates
        # them locally, so all ranks end up with the same totals.
        num_sharded = len(valid_loader.sampler) * world_size
        num_total = len(valid_loader.dataset)
        if num_sharded < num_total:
            aux_val_dataset = Subset(valid_loader.dataset, range(num_sharded, num_total))
            aux_val_loader = DataLoader(aux_val_dataset, batch_size=batch_size, shuffle=False,
                                        num_workers=num_workers, pin_memory=True)
            valid_loss, count = run_validate(aux_val_loader, valid_loss, count)
    finally:
        model.train(was_training)

    avg_val_loss = (valid_loss / count).item()
    # Printed on every rank.
    print(f"Total validation loss = {avg_val_loss:.4f}")
    return avg_val_loss

Double post from here with a follow up.