DDP evaluation / tensorboard logging

I’ve successfully set up DDP with the PyTorch tutorials, but I cannot find any clear documentation about testing/evaluation. I want to do 2 things:

  1. Track train/val loss in tensorboard
  2. Evaluate my model straight after training (in same script).

However, both of these fail: (1) consistently gives me 2 entries per epoch, even though I do not use a distributed sampler for the validation data and the logging should only execute when gpu_id == 0; (2) evaluation on the test set either doesn’t happen at all or is excessively slow. The print statement inside the if statement that checks the GPU id never gets printed, even though the exact same check works fine earlier for snapshot saving and validation-loss calculation.
Only the train dataloader uses a distributed sampler.

Now I can live with the strange tensorboard graphs, but I would really like to be able to evaluate my model in the same script.

Here’s my code:

# python -m torch.distributed.launch main_train_psnr.py --opt options/swinir/train_swinir_sr_classical.json

import os

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from baselines.swinir.utils import EarlyStopper, calculate_psnr, calculate_ssim
from baselines.swinir.dataloader import get_swinir_dataloaders
from baselines.swinir.swinir import SwinIR


DATADIR = os.environ["DATADIR"]
DATASETSDIR = os.environ["DATASETSDIR"]
RESULTSDIR = os.environ["RESULTSDIR"]
SCALE = 2


def ddp_setup():
    init_process_group(backend='nccl')


class Trainer:
    def __init__(self, model: torch.nn.Module, train_data: DataLoader, val_data: DataLoader, test_data: DataLoader, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler, save_every: int, snapshot_path: str, tensorboard_path: str) -> None:
        self.gpu_id = int(os.environ["LOCAL_RANK"])
        self.model = model.to(self.gpu_id)
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_every = save_every
        self.snapshot_path = snapshot_path
        self.epochs_run = 0
        self.batch_size = 16
        if self.gpu_id == 0:
            self.logger = SummaryWriter(tensorboard_path)
        else:
            self.logger = None
        self.early_stopper = EarlyStopper(patience=10)
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[
                         self.gpu_id], find_unused_parameters=True)

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.gpu_id}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        self.early_stopper = snapshot["EARLY_STOPPER"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets, epoch):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.l1_loss(output, targets)
        if self.gpu_id == 0:
            self.logger.add_scalar("Loss/train", loss, epoch)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

    def _validate_batch(self, val_source, val_targets):
        with torch.no_grad():
            output = self.model(val_source)
            return F.l1_loss(output, val_targets).item()

    def _get_epoch_val_loss(self):
        val_loss = 0
        for source, targets in self.val_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            val_loss += self._validate_batch(source, targets)
        return val_loss

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets, epoch)
            if self.gpu_id == 0:
                self.model.eval()
                val_loss = self._get_epoch_val_loss()
                self.logger.add_scalar("Loss/val", val_loss, epoch)
                self.model.train()
                # if self.early_stopper.early_stop(val_loss):
                #     break

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
            "EARLY_STOPPER": self.early_stopper,
        }
        torch.save(snapshot, self.snapshot_path)
        print(
            f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            print(f"[{self.gpu_id}] epoch {epoch}")
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)

    def test(self):
        psnr_values = []
        ssim_values = []
        self.model.eval()
        with torch.no_grad():
            for lr, hr in self.test_data:
                print("test batch")
                lr = lr.to(self.gpu_id)
                sr = self.model(lr).to("cpu")

                psnr_value = calculate_psnr(hr, sr).item()
                psnr_values.append(psnr_value)
                ssim_val = calculate_ssim(hr, sr).item()
                ssim_values.append(ssim_val)
                self.model.train()
            psnr, ssim = (sum(psnr_values)/len(psnr_values),
                          sum(ssim_values)/len(ssim_values))
            print(f"PSNR: {psnr} | SSIM: {ssim}")

    def train_test(self, max_epochs: int):
        self.train(max_epochs)
        print("done training")
        if self.gpu_id == 0:
            print(f"[{self.gpu_id}] start eval...")
            self.test()


def load_train_objs(dataset_name: str, image_size: int, scale: int):
    scale = 2
    train_batch_size = 16

    print("Load data ...")
    test_loader, valid_loader, train_loader = get_swinir_dataloaders(
        dataset_name, train_batch_size)

    model = SwinIR(img_size=image_size, scale=scale, window_size=8, mlp_ratio=2,
                   embed_dim=180, upsampler='pixelshuffle')  # load your model
    optimizer = Adam(model.parameters(), lr=2e-4, weight_decay=0)
    scheduler = MultiStepLR(
        optimizer, [250000, 400000, 450000, 475000, 500000], 0.5)
    return train_loader, valid_loader, test_loader, model, optimizer, scheduler


def main(save_every: int, total_epochs: int, snapshot_path: str = "/data1/wasalaj/snapshots/swinir/snapshot.pt", scale: int = 2):
    ddp_setup()
    train_loader, valid_loader, test_loader, model, optimizer, scheduler = load_train_objs(
        "cerrado", 64, scale)
    trainer = Trainer(model, train_loader, valid_loader, test_loader,
                      optimizer, scheduler, save_every, snapshot_path, "/data1/wasalaj/tensorboard/swinir_test")

    trainer.train_test(total_epochs)
    print("done training...")
    destroy_process_group()

Torch.distributed.barrier() hangs in DDP - #6 by Manuel_Alejandro_Dia

Okay, calling self.model.module helps with getting the test evaluation. I will now also check whether it solves the tensorboard issue.
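
For anyone else hitting this, my understanding of the difference (roughly, hedging a bit since I haven’t dug into the DDP internals) is:

# inside Trainer.test(), which only runs on rank 0
sr = self.model(lr)         # forward through the DDP wrapper: this can synchronize with the
                            # other ranks (e.g. buffer broadcasting), so it can hang when the
                            # other processes never reach this call
sr = self.model.module(lr)  # the plain underlying module: no communication, so it is safe
                            # to call from a single rank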

It’s a nope for tensorboard: the graphs still show 2 points per epoch.

Yes, this is because you need to do some gather operations to sync the final values across all GPUs, and then log the final validation or train loss to TensorBoard from only one process.
Ping me on Thursday and I will be able to help you with some code

That would be great! I thought this would be the case, since it was mentioned in multiple issues, but I don’t think I’ve found a specific code example for what I need.

Search for this method:

torch.distributed.all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False)

You may find some snippets to do what you want. I am currently limited to my phone, sorry for the lack of formatting
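
For example, a minimal sketch of the pattern (assuming the process group is already initialized; writer, epoch and the literal loss value are placeholders for whatever you already have):

import torch
import torch.distributed as dist

rank = dist.get_rank()  # single node here, so the rank doubles as the local GPU id

# every rank computes its own local value, e.g. an epoch-averaged validation loss
local_val_loss = torch.tensor([0.123], device=f"cuda:{rank}")  # placeholder value

# the collective must be called by EVERY rank; afterwards the tensor holds the sum
dist.all_reduce(local_val_loss, op=dist.ReduceOp.SUM)

# only one process averages and logs
if rank == 0:
    avg = local_val_loss.item() / dist.get_world_size()
    writer.add_scalar("Loss/val", avg, epoch)  # writer/epoch: your SummaryWriter and step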

Right, for me the main question is where to put it. I’ve experimented with a few things, but right now my validation loss is a scalar per batch, so I would need to all_reduce at every batch, which isn’t efficient.

So here is my approach with a dummy example; pay attention to the comments:


import torch
import torch.distributed as dist


def train_dist(gpuid,
               world_size,
               epochs):
    
    # We make sure the subprocess works on the intended gpu
    torch.cuda.set_device(gpuid)
    device = torch.device('cpu') if not torch.cuda.is_available() else torch.device(f'cuda:{gpuid}')
    
    dist.init_process_group(                                   
        backend='nccl',                                         
        init_method='env://',                                   
        world_size=world_size,                             
        rank=gpuid                                               
    )
   
    # I am omitting the definition of
    #    model + optimizer + loss_fn + trainloader + valloader
    # to make the snippet shorter.
    # Make sure model + optimizer + loss_fn are on the proper
    # device by calling ".to(device)"

    for epoch in range(epochs):
        for bn, batch in enumerate(trainloader):
            inputs, ground_truth = batch
            
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            
            output = model(inputs)
            loss = loss_fn(output, ground_truth)
            loss.backward()

            optim.step()
            optim.zero_grad()
     
            dist.barrier()

        # After each epoch, I do a validation pass using a validation method.
        # This way I run the same method across all subprocesses (GPUs)
        # and get per-process validation metrics, which are then averaged
        # for tensorboard logging

        model.eval()
        val_loss = validation(model, valloader, loss_fn, device)

        dist.barrier()

        # Now we will do the reduction pass for the validation losses
        # across all GPUs. The idea here is that we may get different
        # val_loss values across GPUs, since they validate with different
        # slices of the data, so we need to take an average: add them
        # together, then divide by the number of GPUs we are using
        # (a.k.a. world_size)

        # In order to do the sum across devices, the variable needs to be a
        # tensor with a size of at least 1. So it should not be a plain
        # scalar; if it is, you will need to put it into a 1-d tensor.

        # Let's assume that val_loss is a scalar, then you will need to do:
        val_loss = torch.Tensor([val_loss]).cuda()

        # Then, you perform the reduction (SUM in this case) across all devices.
        # Note that every rank has to call this; only the logging below is rank-0 only.
        dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)

        # Finally we can log to tensorboard, but we will only do this on GPU 0

        if gpuid == 0:
            # We perform the average by dividing the reduced tensor by the
            # number of GPUs we are using (a.k.a. world_size)

            val_loss = val_loss / world_size

            # Then we log the value, and can even save the model.
            # REMEMBER: you should only do this on one GPU, if not
            # you may end up overwriting the saved model.

            tensorboard_logger.add_scalar('val/loss', val_loss, epoch)

            # Notice that I save the 'model.module' weights, not the DDP wrapper's
            torch.save(model.module.state_dict(), 'model_{}.pt'.format(epoch))
            torch.save(optim.state_dict(), 'optim_{}.pt'.format(epoch))

        # Finally we wait for the GPU 0 subprocess to finish and
        # go back to training mode

        model.train()
        dist.barrier()

Here is a small example of how I use the above code to do parallel training:

import argparse
import os
import socket

import torch.multiprocessing as mp


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--nodes', default=1, type=int, metavar='N', help='Number of computers to use in training')
    parser.add_argument('--gpus', default=1, type=int, help='number of gpus per node')
    parser.add_argument('--epochs', default=10000, type=int, metavar='eps', help='number of total epochs to run')

    args, unknown = parser.parse_known_args()
    args.world_size = args.gpus * args.nodes
    os.environ['MASTER_ADDR'] = socket.gethostname()
    os.environ['MASTER_PORT'] = '8888'

    # Notice that we don't pass the gpuid value here; it is supplied by the mp.spawn method
    mp.spawn(train_dist, nprocs=args.gpus, args=(args.world_size, args.epochs))


if __name__ == '__main__':
    main()

And an example of a function you can use to do the validation pass that will adapt to the GPU device id:

def validation(model, valloader, loss_fn, device):

    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in valloader:
            inputs, ground_truth = batch

            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)

            output = model(inputs)
            
            # We accumulate the loss and multiply it by the batch size
            # since we will return an average based on the number of
            # samples available in the validation dataloader

            total_loss += loss_fn(output, ground_truth) * inputs.shape[0]
    
    # We return the average loss value
    return total_loss / len(valloader.dataset)
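
One more note: if you launch with torchrun instead of mp.spawn, you don't spawn the processes yourself; torchrun starts one process per GPU and sets RANK / LOCAL_RANK / WORLD_SIZE (plus MASTER_ADDR / MASTER_PORT) in the environment. A rough sketch of the equivalent entry point, assuming a single node with 2 GPUs launched as torchrun --standalone --nproc_per_node=2 train.py:

import os

import torch
import torch.distributed as dist


def main(epochs: int = 10):
    # torchrun has already set the rendezvous environment variables for us
    dist.init_process_group(backend='nccl')
    gpuid = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(gpuid)
    world_size = dist.get_world_size()

    # ... build the model / dataloaders and run the same per-epoch loop,
    # all_reduce and rank-0 logging as in train_dist above ...

    dist.destroy_process_group()


if __name__ == '__main__':
    main()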

I hope this helps! Let me know if you have any questions.

Hey, thanks a lot for the example, it helps a lot. I think I’m still doing something wrong, though: I have added the distributed sampler to my validation dataloader, and replaced model.module with model in my validation function, so validation is distributed like in your example. I’ve also applied the same steps you show for the validation loss to the training loss.
The first epoch completes on both of my GPUs, then validation takes place on GPU 0, but GPU 1 does nothing. Then, because of the barriers, it hangs. Any clue what I could be missing? (I’m running with torchrun, btw.)

Here’s my updated code:

def ddp_setup():
    init_process_group(backend='nccl')

class Trainer:
    def __init__(self, model: torch.nn.Module, train_data: DataLoader, val_data: DataLoader, test_data: DataLoader, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler, save_every: int, snapshot_path: str, tensorboard_path: str, world_size: int = 2) -> None:
        self.gpu_id = int(os.environ["LOCAL_RANK"])
        self.model = model.to(self.gpu_id)
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_every = save_every
        self.snapshot_path = snapshot_path
        self.epochs_run = 0
        self.batch_size = 16
        self.world_size = world_size
        if self.gpu_id == 0:
            self.logger = SummaryWriter(tensorboard_path)
        else:
            self.logger = None
        self.early_stopper = EarlyStopper(patience=10)
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[
                         self.gpu_id], find_unused_parameters=True)

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.gpu_id}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        self.early_stopper = snapshot["EARLY_STOPPER"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets, epoch):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.l1_loss(output, targets)

        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()

    def _validate_batch(self, val_source, val_targets):
        with torch.no_grad():
            output = self.model(val_source)
            val_loss = F.l1_loss(output, val_targets)
            return val_loss.item()

    def _get_epoch_val_loss(self):
        val_loss = 0
        for source, targets in self.val_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            val_loss += self._validate_batch(source, targets)*source.shape[0]
        return val_loss/len(self.val_data.dataset)

    def _run_epoch(self, epoch):
        self.train_data.sampler.set_epoch(epoch)
        train_loss = 0
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            train_loss += self._run_batch(source,
                                          targets, epoch)*source.shape[0]

        dist.barrier()

        train_loss /= len(self.train_data.dataset)
        train_loss = torch.Tensor([train_loss]).cuda()
        dist.all_reduce(train_loss, op=dist.ReduceOp.SUM)

        self.model.eval()
        print(f"[{self.gpu_id}] validation...")
        val_loss = torch.Tensor([self._get_epoch_val_loss()]).cuda()

        dist.barrier()
        dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)

        if self.gpu_id == 0:
            val_loss = val_loss/self.world_size
            self.logger.add_scalar("Loss/val", val_loss, epoch)

            train_loss = train_loss/self.world_size
            self.logger.add_scalar("Loss/train", train_loss, epoch)

        self.model.train()
        dist.barrier()

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
            "EARLY_STOPPER": self.early_stopper,
        }
        torch.save(snapshot, self.snapshot_path)
        print(
            f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            print(f"[{self.gpu_id}] epoch {epoch}")
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)

    def test(self):
        psnr_values = []
        ssim_values = []
        self.model.eval()
        with torch.no_grad():
            for lr, hr in self.test_data:
                lr = lr.to(self.gpu_id)
                sr = self.model.module(lr).to("cpu")

                psnr_value = calculate_psnr(hr, sr).item()
                psnr_values.append(psnr_value)
                ssim_val = calculate_ssim(hr, sr).item()
                ssim_values.append(ssim_val)
            self.model.train()
            psnr, ssim = (sum(psnr_values)/len(psnr_values),
                          sum(ssim_values)/len(ssim_values))
            print(f"PSNR: {psnr} | SSIM: {ssim}")

    def train_test(self, max_epochs: int):
        self.train(max_epochs)
        print("done training")
        if self.gpu_id == 0:
            self.test()


def load_train_objs(dataset_name: str, image_size: int, scale: int):
    train_batch_size = 16

    print("Load data ...")
    test_loader, valid_loader, train_loader = get_swinir_dataloaders(
        dataset_name, train_batch_size)

    model = SwinIR(img_size=image_size, scale=scale, window_size=8, mlp_ratio=2,
                   embed_dim=180, upsampler='pixelshuffle')  # load your model
    optimizer = Adam(model.parameters(), lr=2e-4, weight_decay=0)
    scheduler = MultiStepLR(
        optimizer, [250000, 400000, 450000, 475000, 500000], 0.5)
    return train_loader, valid_loader, test_loader, model, optimizer, scheduler


def main(save_every: int, total_epochs: int, snapshot_path: str, scale: int = 2):
    ddp_setup()
    train_loader, valid_loader, test_loader, model, optimizer, scheduler = load_train_objs(
        "cerrado", 64, scale)
    trainer = Trainer(model, train_loader, valid_loader, test_loader,
                      optimizer, scheduler, save_every, snapshot_path, "/data1/tensorboard/swinir_test")

    trainer.train_test(total_epochs)

    destroy_process_group()

Depending on the size of your dataset, it may not be necessary to do distributed evaluation during training and you can just run evaluation on rank 0. For logging during training, I also define a logging function which checks the rank of each process and only writes to a file on rank 0. I never make any calls to dist.barrier during training unless I’m doing validation at the end of each epoch. This is the code I use below as part of my training function, and it always works fine: it logs, evaluates and saves the model at checkpoints:

for epoch in range(args.epochs):
    model.train()
    logger(args, f"Epoch: {epoch}...")
    for ix, (x, y) in enumerate(trainloader):
        # Zero-out gradients
        optim.zero_grad()

        # Send inputs to GPU
        x = x.to(args.gpu)

        # Forward pass
        outputs = model.forward(x)

        # Backward pass
        loss = criterion(outputs, x)
        loss.backward()
        optim.step()

        # Logging
        if ix % args.plot_freq == 0:
            logger(args, f"\tEpoch: {epoch} == Batch: {ix} == MSE: {loss.item():.4f}")

    # Checkpoint save
    if epoch % args.save_freq == 0:
        if args.rank == 0:
            validate(model.module.to(args.gpu), valloader, params)
            dist.barrier()
        else:
            dist.barrier()
        torch.save(model.module.state_dict(),
                   os.path.join(args.model_path, f"{args.arch}_z_{args.latent_dim}_epoch_{epoch}.pth"))
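
The logger function itself isn’t included above; it is basically just a rank check, something along these lines (a rough sketch, the file name is only an example):

import os

def logger(args, message):
    # every rank can call this, but only rank 0 prints / appends to the log file,
    # so the output isn't duplicated once per GPU
    if args.rank == 0:
        print(message)
        with open(os.path.join(args.model_path, "train.log"), "a") as f:
            f.write(message + "\n")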

I think just try to keep it simple, i.e., avoid making lots of calls to dist.barrier, since much of the synchronisation is abstracted away and handled by the backends already.

I finally got it all to work. I don’t fully understand why, but barriers and/or all_reduce calls in the wrong places seem to make it all clog up. This is the training function that finally works for me:

    def _run_epoch(self, epoch):
        self.train_data.sampler.set_epoch(epoch)
        train_loss = 0
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            train_loss += self._run_batch(source,
                                          targets, epoch)*source.shape[0]

        dist.barrier()
        self.model.eval()
        val_loss = self._get_epoch_val_loss()
        dist.barrier()
        val_loss = torch.Tensor([val_loss]).to(self.gpu_id)
        dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)

        train_loss /= len(self.train_data.dataset)
        train_loss = torch.Tensor([train_loss]).to(self.gpu_id)
        dist.all_reduce(train_loss, op=dist.ReduceOp.SUM)

        if self.gpu_id == 0:
            self.logger.add_scalar("Loss/val", val_loss/self.world_size, epoch)
            self.logger.add_scalar(
                "Loss/train", train_loss/self.world_size, epoch)
        self.model.train()
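
If it helps anyone else, the rule of thumb I’ve pieced together is: every rank has to reach the same collectives (dist.barrier / dist.all_reduce) in the same order, and only the logging/saving should be gated on the rank. Schematically, with the same names as above:

# hangs: only rank 0 calls the collective, so it blocks waiting for ranks that never join
if self.gpu_id == 0:
    val_loss = torch.Tensor([self._get_epoch_val_loss()]).to(self.gpu_id)
    dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)

# works: every rank validates and reduces, only rank 0 logs
val_loss = torch.Tensor([self._get_epoch_val_loss()]).to(self.gpu_id)
dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
if self.gpu_id == 0:
    self.logger.add_scalar("Loss/val", val_loss / self.world_size, epoch)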

Thanks! It helps to see some more examples from other people.