DDP with learning rate schedulers

I would like to use a learning rate scheduler to (potentially) adapt the learning rate after each epoch, depending on a metric gathered from the validation dataset. However, I am not sure how to use the ReduceLROnPlateau scheduler within DDP. As far as I know, the best practice is to calculate the validation metric on the rank-0 GPU, but how do I then communicate the new learning rate to the other GPUs? My current train function looks like this:

for epoch in range(5):
    train_sampler.set_epoch(epoch)
    val_sampler.set_epoch(epoch)
    self.optimizer.zero_grad()
    for features, labels in train_loader:
        features, labels = features.to(self.device), labels.to(self.device)

        with torch.cuda.amp.autocast():
            preds = self.model(features)
            loss = self.criterion(preds, labels)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()


    if rank == 0:
        with torch.no_grad():
            ... # get validation metric
    if rank == 0:
        self.scheduler.step(-(epoch//3))  # dummy for demonstration purposes
    dist.barrier()
    current_lr = float(self.optimizer.param_groups[0]['lr'])
    print(f"LR on rank {rank} in epoch {epoch + 1}: {current_lr}")

The scheduler is defined like this:

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=0)
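
For reference, this is how that scheduler reacts to the dummy metric -(epoch//3) in a single process (a minimal standalone sketch, not part of the DDP run): with patience=0, the learning rate is multiplied by factor on every step whose metric is not better than the best value seen so far.

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=0)

for epoch in range(5):
    scheduler.step(-(epoch // 3))            # metrics: 0, 0, 0, -1, -1
    print(optimizer.param_groups[0]["lr"])   # 0.1, 0.01, 0.001, 0.001, 0.001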

The command on the cluster is the following:

torchrun --nproc_per_node=2 filename.py

The output is the following:

LR on rank 1 in epoch 1: 0.1
LR on rank 0 in epoch 1: 0.1
LR on rank 1 in epoch 2: 0.1
LR on rank 0 in epoch 2: 0.01
LR on rank 1 in epoch 3: 0.1
LR on rank 0 in epoch 3: 0.001
LR on rank 1 in epoch 4: 0.1
LR on rank 0 in epoch 4: 0.001
LR on rank 1 in epoch 5: 0.1
LR on rank 0 in epoch 5: 0.001

I would like the learning rates on both GPUs to be the same. In this case, I would expect the third line of the printout (rank 1, epoch 2) to already show a learning rate of 0.01.

For reproducibility, here is the full code:

import torch
import torch.nn as nn
import os
import torch.nn.functional as F
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler
from torch.utils.data import Dataset, DataLoader, TensorDataset


class Model(object):
    def __init__(self, model, device, optimizer, scaler, scheduler):
        self.optimizer = optimizer
        self.device = device
        self.scaler = scaler
        self.scheduler = scheduler
        self.criterion = torch.nn.MSELoss().to(self.device)
        self.model = model.to(self.device)


    def train(self, train_loader, train_sampler, val_loader, val_sampler, rank):
        for epoch in range(5):
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)
            self.optimizer.zero_grad()

            for features, labels in train_loader:
                features, labels = features.to(self.device), labels.to(self.device)

                with torch.cuda.amp.autocast():
                    preds = self.model(features)
                    loss = self.criterion(preds, labels)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()


            if rank == 0:
                with torch.no_grad():
                    running_loss = 0
                    for features, labels in val_loader:
                        features, labels = features.to(self.device), labels.to(self.device)
                        preds = self.model(features)
                        loss = self.criterion(preds, labels)
                        running_loss += loss.item()
            if rank == 0:
                self.scheduler.step(-(epoch//3))  # dummy for demonstration purposes
            dist.barrier()
            current_lr = float(self.optimizer.param_groups[0]['lr'])
            print(f"LR on rank {rank} in epoch {epoch + 1}: {current_lr}")


def train_pipeline():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    batch_size = 32


    num_features = 100
    train_inps = torch.rand(size=(320, num_features), dtype=torch.float32)
    train_tgts = torch.rand(size=(320, 1), dtype=torch.float32)
    train_dataset = TensorDataset(train_inps, train_tgts)
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=5, pin_memory=True, drop_last=True)

    val_inps = torch.rand(size=(160, num_features), dtype=torch.float32)
    val_tgts = torch.rand(size=(160, 1), dtype=torch.float32)
    val_dataset = TensorDataset(val_inps, val_tgts)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, num_workers=5, pin_memory=True, drop_last=True)

    model = nn.Sequential(nn.Linear(num_features, 10), nn.ReLU(), nn.Linear(10, 1)).to(device)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=0)
    my_model = Model(model=model, device=device, optimizer=optimizer, scaler=scaler, scheduler=scheduler)
    my_model.train(train_loader, train_sampler, val_loader, val_sampler, rank)
    dist.destroy_process_group()


if __name__ == "__main__":
    train_pipeline()

Hi @yaiza612, you would need to do some communication (e.g. torch.distributed.broadcast) to transfer the learning rate from rank 0 to the other GPUs, e.g. after the dist.barrier() at the end of each epoch, something like:

current_lr = self.optimizer.param_groups[0]['lr']
#### NEW CODE
# every rank builds a tensor from its own LR; the broadcast then overwrites it
# with the value from rank 0, where the scheduler was stepped
lr_tensor = torch.tensor(current_lr, device=self.device)
dist.broadcast(lr_tensor, src=0)
# apply the rank-0 learning rate to every param group on every rank
for param_group in self.optimizer.param_groups:
    param_group['lr'] = lr_tensor.item()
#### end
print(f"LR on rank {rank} in epoch {epoch + 1}: {lr_tensor.item()}")