Using TensorBoard with DistributedDataParallel

The only option seems to be to log from a single process (rank 0) and leave the logger unset on all other ranks. This code runs fine:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributions import Laplace


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dens1 = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = self.dens1(x)
        # keep the scale tensor on the same device as the loc to avoid CPU/CUDA mismatches
        x = Laplace(x, torch.tensor([1.0], device=x.device))
        return x


class RandomDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        # shape (16,) so a batch is (batch_size, 16) and the model output matches the (batch_size, 3) labels
        sample = {'mod1': torch.rand(16).float(),
                  'mod2': torch.rand(16).float(),
                  'mod3': torch.rand(16).float()}

        label = torch.randint(0, 1, (3,)).float()
        return sample, label

    def __len__(self):
        return 20


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class Experiment:
    def __init__(self, distributed: bool, dir_logs: str):
        self.logger = None
        self.dir_logs = dir_logs
        self.model = ToyNet()
        self.rank = None
        self.distributed = distributed
        if distributed:
            self.world_size = torch.cuda.device_count()
            assert self.world_size > 1, 'At least two GPUs need to be accessible to use distributed training'
        else:
            self.world_size = 1

    def setup_logger(self):
        # initialize summary writer
        if not os.path.exists(self.dir_logs):
            os.makedirs(self.dir_logs)
        self.logger = SummaryWriter(self.dir_logs)


def train(exp: Experiment, rand_loader: DataLoader):
    rank = exp.rank
    model = exp.model.to(rank)
    if exp.distributed:
        # wrap the already-moved module; re-wrapping every epoch is wasteful but fine for this toy example
        model = DDP(model, device_ids=[rank])

    mse_loss = nn.MSELoss()
    for step, (batch, label) in enumerate(rand_loader):
        # move the whole batch to the device once per step instead of once per modality
        label = label.to(rank)
        batch = {k: v.to(rank) for k, v in batch.items()}
        for modality in batch.keys():
            output = model(batch[modality]).mean
            loss = mse_loss(output, label)
            # only the rank that created the SummaryWriter logs; exp.logger is None on all other ranks
            if exp.logger:
                exp.logger.add_scalars('train/loss',
                                       {'train_loss': loss.item()},
                                       step)
            loss.backward()  # no optimizer step; this toy only exercises forward/backward and logging


def validate(exp, rand_loader: DataLoader):
    rank = exp.rank
    # evaluate the underlying module; its parameters are shared with the DDP wrapper
    model = exp.model.eval()
    with torch.no_grad():
        mse_loss = nn.MSELoss()
        for step, (batch, label) in enumerate(rand_loader):
            label = label.to(rank)
            batch = {k: v.to(rank) for k, v in batch.items()}
            for modality in batch.keys():
                output = model(batch[modality]).mean
                loss = mse_loss(output, label)
                if exp.logger:
                    exp.logger.add_scalars('val/loss',
                                           {'val_loss': loss.item()},
                                           step)


def run_epochs(rank, exp: Experiment):
    # rank is an int when launched via mp.spawn and a torch.device in the single-process case
    print(f"Running basic DDP example on rank {rank}.")
    exp.rank = rank
    if not exp.distributed or rank == 0:
        # create the SummaryWriter on a single process only; all other ranks keep exp.logger = None
        print(f'setting up logger for rank {rank}')
        exp.setup_logger()
    dataset = RandomDataset()
    if exp.distributed:
        setup(rank, exp.world_size)
        sampler = DistributedSampler(dataset, num_replicas=exp.world_size, rank=rank)
    else:
        sampler = None
    rand_loader = DataLoader(dataset=dataset,
                             batch_size=8, shuffle=False, pin_memory=True, sampler=sampler, num_workers=0)
    for epoch in range(5):
        if exp.distributed:
            sampler.set_epoch(epoch)
        train(exp, rand_loader)
        validate(exp, rand_loader)

    if exp.distributed:
        cleanup()
    if exp.logger:
        exp.logger.close()
    print('done!')


if __name__ == '__main__':
    log_dir = 'temp_dir'
    use_distributed_training = True  # set to False to run everything in a single process without DDP
    ex = Experiment(use_distributed_training, log_dir)
    if ex.distributed:
        mp.spawn(run_epochs,
                 args=(ex,),
                 nprocs=ex.world_size,
                 join=True)
    else:
        run_epochs(torch.device('cuda'), ex)

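For reference, the same rank-0-only pattern should also work with PyTorch's built-in writer instead of tensorboardX. A minimal sketch, assuming torch.utils.tensorboard is available (PyTorch 1.2+); the helper name maybe_create_writer is made up for illustration:

import os
from torch.utils.tensorboard import SummaryWriter  # drop-in alternative to tensorboardX


def maybe_create_writer(rank: int, log_dir: str):
    # Only rank 0 writes event files; every other rank gets None and skips its logging calls.
    if rank == 0:
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(log_dir)
    return None

Either way, the logged scalars can be viewed by pointing TensorBoard at the log directory, e.g. tensorboard --logdir temp_dir.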