How do I validate with PyTorch DistributedDataParallel?

Hi,

When using DDP, should I validate the model after each epoch only on rank 0, or on all processes? With my current code, every process prints its own validation loss. What is the correct way to validate with DDP? Here is my code (I am still prototyping, so the validation and training loops use the same DataLoader):

def validate(model, train_loader):

    # Validate the model.
    model.eval()

    validation_loss = 0.0
    with torch.no_grad():
        for i, (images, labels, wsi_id) in enumerate(train_loader):

            # Pass data and label to the GPU.
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            images = torch.squeeze(images).contiguous()

            # Forward pass with autocasting/mixed-precision,
            # GradScaler is not needed for inference.
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits, Y_prob, Y_hat, _ = model(images)
                vloss = loss_fn(logits, labels)
            validation_loss += vloss

        print(f'Validation loss {validation_loss / (i + 1):.4f}')

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Don't worry be happy!")
    parser.add_argument('--df_path', type=str, help='', default='')
    parser.add_argument('--shard_path', type=str, help='', default='')
    parser.add_argument('--model_output_path', type=str, help='Model checkpoint output path', default='')
    parser.add_argument('--batch_size', type=int, help='Batch size per GPU. Default batch_size is in bag-level (1 WSI)', default=1)
    parser.add_argument('--epochs', type=int, help='Number of training epochs', default=1000000)
    parser.add_argument('--lr', type=float, help='Learning rate (default: 0.0001)', default=0.0001)
    parser.add_argument('--num_workers', type=int, help='Number of dataloader processes for each GPU', default=8)
    parser.add_argument('--seed', type=int, help='Random seed (default: 1)', default=1)
    args = parser.parse_args()

    df_path = args.df_path
    shard_path = args.shard_path
    model_output_path = args.model_output_path
    epochs = args.epochs
    batch_size = args.batch_size
    num_workers = args.num_workers
    lr = args.lr

    shard_df = pd.read_csv(df_path)
    seed_torch(args.seed)

    # Env variables set by the launcher (e.g. torchrun).
    local_rank = int(os.environ['LOCAL_RANK'])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])

    # Initialize PyTorch DistributedDataParallel.
    init_distributed(local_rank, global_rank, world_size)

    # Initialize train/valid data loaders.
    train_loader = get_dataloader(shard_path, shard_df)

    # Load the model.
    model = Network().cuda()
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # `rank` was undefined here; the output device is this process's local GPU.
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    # Define loss function (criterion) and optimizer.
    loss_fn = nn.CrossEntropyLoss().cuda()
    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.0005)
    # Create a GradScaler for training with mixed precision.
    scaler = GradScaler()

    # Start training.
    start_train = time.time()
    for epoch in range(epochs):
        start_time = time.time()
        train_loader.sampler.set_epoch(epoch)
        # Make sure gradient tracking is on, and do a pass over the data.
        model.train(True)
        train(model, train_loader)
        dist.barrier()
        print(f"Epoch took: {(time.time() - start_time) / 60:.2f} mins")

        model.train(False)
        validate(model, train_loader)

    epoch_time = (time.time() - start_train) / 60
    print(f"Training took: {epoch_time:.2f} mins")

    # is_main_process is called, not tested as a (always truthy) function object.
    if is_main_process():
        # Save GradScaler values for continued training of the model with
        # amp mixed precision.
        checkpoint = {"model": model.state_dict(),
                      "optimizer": optimizer.state_dict(),
                      "scaler": scaler.state_dict()}
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in the main process is sufficient.
        save_on_master(checkpoint, model_output_path + "/model_prototype.pt")

    # Destroy the process group last: if is_main_process() queries the group
    # (as torchvision-style helpers do), every rank would see rank 0 after cleanup.
    cleanup()

You could use the validate method from the ImageNet example as a template and adapt it to your multi-GPU use case.
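The core of that pattern can be sketched minimally as follows, assuming the process group is already initialized: every rank evaluates its own shard, accumulates a local loss sum and sample count, and then all ranks all_reduce both numbers before dividing. `global_mean` is a hypothetical helper name, not a torch.distributed API.

```python
import torch
import torch.distributed as dist


def global_mean(local_sum: float, local_count: int, device: str = "cpu") -> float:
    """Average a quantity over all ranks (hypothetical helper).

    Each rank passes the sum of the per-sample values it computed and the
    number of samples behind that sum. Both are packed into one tensor so a
    single all_reduce yields the global sum and global count on every rank.
    """
    # With the NCCL backend, pass device="cuda" (the rank's current GPU).
    stats = torch.tensor([local_sum, float(local_count)], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_sum, total_count = stats.tolist()
    return total_sum / total_count
```

Because all_reduce is a collective, every rank has to call it; calling it from rank 0 alone would leave rank 0 blocked waiting for the others. Afterwards all ranks hold the same value, so printing or logging can then be restricted to rank 0.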


Thank you @ptrblck for your answer. I have implemented validation as you suggested (code below). Do you think it is right? In particular, did I use dist.barrier() and dist.all_reduce correctly here, and should the averaged validation loss be calculated as below?

I would really really appreciate your feedback on this.

def validate(valid_loader, model, loss_fn, optimizer, scaler, world_size,):

    model.eval()

    validation_loss = 0.0
    with torch.no_grad():
        for i, (images, labels,  wsi_id) in enumerate(valid_loader):

            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            images = torch.squeeze(images)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                logits, Y_prob, Y_hat, _ = model(images)
                vloss = loss_fn(logits, labels)
            validation_loss += vloss

        # Wait until all processes are finished with the evaluation and
        # reduce (operation: SUM) the loss value across all processes, so that
        # each of them get the same final value.
        dist.barrier()
        dist.all_reduce(validation_loss, dist.ReduceOp.SUM, async_op=False)
        # Get averaged validation loss from all processes.
        validation_loss = validation_loss / ((i + 1) * world_size)

        print(f"Validation loss {validation_loss:.4f}")

I call this method here:

for epoch in range(max_epochs):

        train_loader.sampler.set_epoch(epoch)
        model.train(True)
        train_one_epoch(train_loader, model, loss_fn, optimizer, scaler,
                        local_rank, world_size, epoch, max_epochs,)

        # Validate one epoch.
        model.train(False)
        validate(valid_loader, model, loss_fn, optimizer, scaler, world_size,)


Your code looks alright, but note the special handling of the validation samples that remain when the dataset size is not evenly divisible by the number of ranks.
In particular, the validation DistributedSampler is created with drop_last=True, and this check:

if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):

is used to call into run_validate again with the remaining samples.
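To see which samples that check refers to, the sketch below builds the same sampler configuration in a single process: passing num_replicas and rank explicitly lets DistributedSampler be constructed without an initialized process group. `remainder_indices` is a hypothetical helper, and it assumes shuffle=False, as in the ImageNet validation sampler.

```python
from torch.utils.data.distributed import DistributedSampler


def remainder_indices(dataset_len: int, world_size: int) -> list[int]:
    """Indices a drop_last=True DistributedSampler leaves out (hypothetical helper)."""
    # With drop_last=True every rank gets the same number of samples, so rank 0 suffices.
    sampler = DistributedSampler(range(dataset_len), num_replicas=world_size, rank=0,
                                 shuffle=False, drop_last=True)
    covered = len(sampler) * world_size  # samples handled by the sharded pass
    # Same condition as the ImageNet check: anything past `covered` was dropped.
    return list(range(covered, dataset_len))


print(remainder_indices(5, 2))  # [4]
print(remainder_indices(7, 4))  # [4, 5, 6]
print(remainder_indices(8, 4))  # []
```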


Hi @ptrblck, I implemented it as below. Does this make sense?
I would really really appreciate your feedback on this.

def validate(valid_loader, model, loss_fn, optimizer, scaler, world_size,
            early_stopping, epoch, results_path, batch_size, num_workers,):

    def run_validate(valid_loader):

        validation_loss = 0.0
        count = torch.zeros(1, dtype=torch.float32, device="cuda")

        with torch.no_grad():
            for i, (images, labels,  wsi_id) in enumerate(valid_loader):

                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
                images = torch.squeeze(images)

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    logits, Y_prob, Y_hat, _ = model(images)
                    vloss = loss_fn(logits, labels)

                validation_loss += vloss
                count += 1

        return validation_loss, count

    model.eval()
    valid_loss, count = run_validate(valid_loader)

    if len(valid_loader.sampler) * world_size < len(valid_loader.dataset):
        aux_val_dataset = Subset(valid_loader.dataset, range(len(valid_loader.sampler) * world_size, len(valid_loader.dataset)))
        aux_val_loader = DataLoader(aux_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
        valid_loss_aux, count_aux = run_validate(aux_val_loader)
        # Update the metrics.
        valid_loss = torch.add(valid_loss, valid_loss_aux)
        count = count + count_aux

    dist.barrier()
    dist.all_reduce(valid_loss, dist.ReduceOp.SUM, async_op=False)
    dist.all_reduce(count, dist.ReduceOp.SUM, async_op=False)
    
    avg_val_loss = float((valid_loss / count).detach().cpu())
    print(f"Total validation loss = {avg_val_loss:.4f}")

Also, would the solution you proposed, @ptrblck, be the same as using drop_last=False?

I believe you should all_reduce the losses before running the “rest” of the validation dataset on each worker and add this valid_loss_aux to the allreduced loss.
Besides that your code looks fine.
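A minimal sketch of that ordering, under these assumptions: the validation sampler is a DistributedSampler with shuffle=False and drop_last=True, the model returns logits directly (the model in this thread returns a tuple, so the forward call would need unpacking), and the criterion uses reduction="sum" so batches of different sizes are weighted correctly. The function names are hypothetical, and mixed precision is left out for brevity (the torch.autocast block from the thread can wrap the forward pass).

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Subset


@torch.no_grad()
def sum_loss(model, loader, loss_fn, device):
    """Return (summed loss, sample count) over `loader` as 1-element float64 tensors."""
    total = torch.zeros(1, dtype=torch.float64, device=device)
    count = torch.zeros(1, dtype=torch.float64, device=device)
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        logits = model(inputs)
        # loss_fn is assumed to use reduction="sum", so this adds per-sample losses.
        total += loss_fn(logits, labels)
        count += labels.numel()
    return total, count


def validate(model, valid_loader, loss_fn, device, batch_size, num_workers=0):
    """Mean validation loss over the whole dataset, identical on every rank."""
    model.eval()
    world_size = dist.get_world_size()

    # 1) Sharded pass: each rank evaluates its own disjoint slice.
    loss_sum, count = sum_loss(model, valid_loader, loss_fn, device)

    # 2) Reduce the sharded totals BEFORE the remainder pass, so every rank
    #    now holds the loss sum and count of all sharded samples.
    dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
    dist.all_reduce(count, op=dist.ReduceOp.SUM)

    # 3) Remainder pass over the samples dropped by drop_last=True. Every rank
    #    evaluates the same leftovers and adds them locally, after the
    #    reduction, so they are counted exactly once. range(covered, len) is
    #    the dropped set only when the sampler uses shuffle=False.
    covered = len(valid_loader.sampler) * world_size
    if covered < len(valid_loader.dataset):
        rest = Subset(valid_loader.dataset, range(covered, len(valid_loader.dataset)))
        rest_loader = DataLoader(rest, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        rest_sum, rest_count = sum_loss(model, rest_loader, loss_fn, device)
        loss_sum += rest_sum
        count += rest_count

    return (loss_sum / count).item()
```

No explicit dist.barrier() is needed before the all_reduce calls, since the collective already waits for every rank to contribute.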

Yes, using drop_last=False might work during training, but note that this use case is a bit more complicated than in the single-GPU use case.
In DDP you should make sure that each rank processes the same number of batches; otherwise the ranks diverge, which can result in hangs.
Depending on drop_last, the DistributedSampler either drops the tail of the dataset or pads it with repeated samples to make it evenly divisible by the number of ranks, as seen here.
This is no problem during training as dropping some samples or repeating others would usually not create issues.
However, during validation you often care about the “real” validation loss and accuracy. Repeating samples or dropping them could bias your validation loss, which sounds bad.
This is why the ImageNet example first cuts the dataset (if needed) and executes the validation calculation using the “rest” of the samples in each rank before accumulating the loss.
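A small sketch of what DistributedSampler hands to each rank for a 5-sample dataset on 2 ranks, built in one process by passing num_replicas and rank explicitly. `shards` is a hypothetical helper; shuffle=False is assumed so the output is deterministic.

```python
from torch.utils.data.distributed import DistributedSampler


def shards(n: int, world_size: int, drop_last: bool) -> list[list[int]]:
    """Indices each rank receives from a non-shuffled DistributedSampler (hypothetical helper)."""
    return [
        list(DistributedSampler(range(n), num_replicas=world_size, rank=r,
                                shuffle=False, drop_last=drop_last))
        for r in range(world_size)
    ]


print(shards(5, 2, drop_last=False))  # [[0, 2, 4], [1, 3, 0]]
print(shards(5, 2, drop_last=True))   # [[0, 2], [1, 3]]
```

With drop_last=False, index 0 is padded onto rank 1 and would be counted twice by an all_reduce; with drop_last=True, index 4 never appears in the sharded pass, which is what the extra remainder pass recovers. Note also that ranks receive strided indices (rank r gets r, r + world_size, ...) rather than contiguous blocks.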


Thank you @ptrblck for the thorough answer about this, very helpful.

I have one more question regarding this scenario:

With the code above, DistributedSampler with drop_last=True (as in the ImageNet example), and batch_size=1, the “rest” of the data is still evaluated on all GPUs. For example, with 2 GPUs and 5 examples [1,2,3,4,5], in the first pass examples [1,2] are evaluated on GPU1 and [3,4] on GPU2 in parallel. In the additional run over the “rest” of the examples, example [5] is evaluated on both GPU1 and GPU2.

Isn’t this similar to using DistributedSampler with drop_last=False, where the examples would be evaluated as [1,2,3] on GPU1 and [4,5,1] on GPU2? In both cases one example is evaluated twice.

No, I don’t think the two approaches are equivalent.
The first one uses different samples on each device and reduces the loss, so that both devices have the same loss sum from samples [1, 2, 3, 4]. Afterwards, both devices will calculate the loss of sample 5 and add it to their local and already reduced loss which would then yield the validation loss of all samples. Think about it as: loss of [1, 2, 3, 4] + local loss of 5 = loss of [1, 2, 3, 4, 5].

The second approach would repeat sample 1 and then all_reduce it, counting that sample twice and biasing the loss.
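The arithmetic can be checked with made-up per-sample losses (hypothetical values, chosen to be exactly representable in floating point) for the 2-GPU, 5-sample case discussed above:

```python
# Hypothetical per-sample losses for samples 1..5 (made-up numbers).
losses = {1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0, 5: 1.25}


def imagenet_style_mean() -> float:
    # Sharded pass: the all_reduce leaves every rank with the loss of [1, 2, 3, 4].
    reduced_sum, reduced_count = sum(losses[i] for i in (1, 2, 3, 4)), 4
    # Remainder pass: each rank adds sample 5 locally to the already-reduced totals.
    return (reduced_sum + losses[5]) / (reduced_count + 1)


def padded_mean() -> float:
    # drop_last=False pads by repeating sample 1, and the all_reduce counts it twice.
    padded = [1, 2, 3, 4, 5, 1]
    return sum(losses[i] for i in padded) / len(padded)


true_mean = sum(losses.values()) / len(losses)
print(imagenet_style_mean(), true_mean, round(padded_mean(), 4))  # 0.75 0.75 0.6667
```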


I understand now, @ptrblck. Thank you so much for your help and guidance on this problem. Much appreciated.