Using tensorboard with DistributedDataParallel

Hello, I am trying to make my workflow run on multiple GPUs. Since torch.nn.DataParallel did not work out for me (see this discussion), I am now trying to use torch.nn.parallel.DistributedDataParallel (DDP). However, I am not sure how to use the TensorBoard logger when doing distributed training. Previous questions about this topic remain unanswered (here and here).
I have set up a typical training workflow that runs fine without DDP (use_distributed_training=False), but fails with it, raising TypeError: cannot pickle '_io.BufferedWriter' object.
Is there any way to make this code run, using both tensorboard and DDP?

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributions import Laplace


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dens1 = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = self.dens1(x)
        x = Laplace(x, torch.tensor([1.0]))
        return x


class RandomDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        sample = {'mod1': torch.rand(1, 16).float(),
                  'mod2': torch.rand(1, 16).float(),
                  'mod3': torch.rand(1, 16).float()}

        label = torch.randint(0, 1, (3,)).float()
        return sample, label

    def __len__(self):
        return 20


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class Experiment:
    def __init__(self, distributed, dir_logs):
        # initialize summary writer
        if not os.path.exists(dir_logs):
            os.makedirs(dir_logs)
        self.logger = SummaryWriter(dir_logs)
        self.model = ToyNet()
        self.rank = None
        self.distributed = distributed
        if distributed:
            self.world_size = torch.cuda.device_count()
            assert self.world_size > 1, 'More than 1 GPU needs to be accessible to use distributed training'
        else:
            self.world_size = 1


def train(exp: Experiment):
    rank = exp.rank
    if exp.distributed:
        model = DDP(exp.model, device_ids=[rank])
        sampler = DistributedSampler(RandomDataset(), num_replicas=exp.world_size, rank=rank)
    else:
        model = exp.model.to(rank)
        sampler = None
    rand_loader = DataLoader(dataset=RandomDataset(),
                             batch_size=8, shuffle=False, pin_memory=True, sampler=sampler, num_workers=0)

    mse_loss = nn.MSELoss()
    for step, (batch, label) in enumerate(rand_loader):
        for modality in batch.keys():
            label = label.to(rank)
            batch = {k: v.to(rank) for k, v in batch.items()}
            output = model(batch[modality]).mean
            loss = mse_loss(output, label)
            exp.logger.add_scalars(f'train/loss',
                                   {'train_loss': loss.item()},
                                   step)


def validate(exp):
    model = exp.model.eval()
    rank = exp.rank
    with torch.no_grad():
        if exp.distributed:
            sampler = DistributedSampler(RandomDataset(), num_replicas=exp.world_size, rank=rank)
        else:
            sampler = None
        rand_loader = DataLoader(dataset=RandomDataset(),
                                 batch_size=8, shuffle=False, pin_memory=True, sampler=sampler, num_workers=0)
        mse_loss = nn.MSELoss()
        for step, (batch, label) in enumerate(rand_loader):
            for modality in batch.keys():
                label = label.to(rank)
                batch = {k: v.to(rank) for k, v in batch.items()}
                output = model(batch[modality]).mean
                loss = mse_loss(output, label)
                exp.logger.add_scalars(f'val/loss',
                                       {'val_loss': loss.item()},
                                       step)


def run_epochs(rank, exp: Experiment):
    print(f"Running basic DDP example on rank {rank}.")
    exp.rank = rank
    if exp.distributed:
        setup(rank, exp.world_size)

    for epoch in range(5):
        train(exp)
        validate(exp)

    if exp.distributed:
        cleanup()
    print('done!')


if __name__ == '__main__':
    log_dir = 'temp_dir'
    use_distributed_training = True
    ex = Experiment(use_distributed_training, log_dir)
    if ex.distributed:
        mp.spawn(run_epochs,
                 args=(ex,),
                 nprocs=ex.world_size,
                 join=True)
    else:
        run_epochs(torch.device('cuda'), ex)


The only option seems to be to log from only one process. This code runs fine:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributions import Laplace


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dens1 = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = self.dens1(x)
        x = Laplace(x, torch.tensor([1.0]))
        return x


class RandomDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        sample = {'mod1': torch.rand(1, 16).float(),
                  'mod2': torch.rand(1, 16).float(),
                  'mod3': torch.rand(1, 16).float()}

        label = torch.randint(0, 1, (3,)).float()
        return sample, label

    def __len__(self):
        return 20


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class Experiment:
    def __init__(self, distributed: bool, dir_logs: str):
        self.logger = None
        self.dir_logs = dir_logs
        self.model = ToyNet()
        self.rank = None
        self.distributed = distributed
        if distributed:
            self.world_size = torch.cuda.device_count()
            assert self.world_size > 1, 'More than 1 GPU needs to be accessible to use distributed training'
        else:
            self.world_size = 1

    def setup_logger(self):
        # initialize summary writer
        if not os.path.exists(self.dir_logs):
            os.makedirs(self.dir_logs)
        self.logger = SummaryWriter(self.dir_logs)


def train(exp: Experiment, rand_loader: DataLoader):
    rank = exp.rank
    model = exp.model.to(rank)
    if exp.distributed:
        model = DDP(exp.model, device_ids=[rank])

    mse_loss = nn.MSELoss()
    for step, (batch, label) in enumerate(rand_loader):
        for modality in batch.keys():
            label = label.to(rank)
            batch = {k: v.to(rank) for k, v in batch.items()}
            output = model(batch[modality]).mean
            loss = mse_loss(output, label)
            if exp.logger:
                exp.logger.add_scalars(f'train/loss',
                                       {'train_loss': loss.item()},
                                       step)
            loss.backward()


def validate(exp, rand_loader: DataLoader):
    rank = exp.rank
    model = exp.model.eval()
    with torch.no_grad():

        mse_loss = nn.MSELoss()
        for step, (batch, label) in enumerate(rand_loader):
            for modality in batch.keys():
                label = label.to(rank)
                batch = {k: v.to(rank) for k, v in batch.items()}
                output = model(batch[modality]).mean
                loss = mse_loss(output, label)
                if exp.logger:
                    exp.logger.add_scalars(f'val/loss',
                                           {'val_loss': loss.item()},
                                           step)


def run_epochs(rank, exp: Experiment):
    print(f"Running basic DDP example on rank {rank}.")
    exp.rank = rank
    if not exp.distributed or (rank % exp.world_size == 0):
        print(f'setting up logger for rank {rank}')
        exp.setup_logger()
    if exp.distributed:
        setup(rank, exp.world_size)
        sampler = DistributedSampler(RandomDataset(), num_replicas=exp.world_size, rank=rank)
    else:
        sampler = None
    rand_loader = DataLoader(dataset=RandomDataset(),
                             batch_size=8, shuffle=False, pin_memory=True, sampler=sampler, num_workers=0)
    for epoch in range(5):
        if exp.distributed:
            sampler.set_epoch(epoch)
        train(exp, rand_loader)
        validate(exp, rand_loader)

    if exp.distributed:
        cleanup()
    if exp.logger:
        exp.logger.close()
    print('done!')


if __name__ == '__main__':
    log_dir = 'temp_dir'
    use_distributed_training = True
    ex = Experiment(use_distributed_training, log_dir)
    if ex.distributed:
        mp.spawn(run_epochs,
                 args=(ex,),
                 nprocs=ex.world_size,
                 join=True)
    else:
        run_epochs(torch.device('cuda'), ex)


CC @orionr for the TensorBoard question

@Jimmy2027: I was able to make logging work by moving the SummaryWriter creation from the main process to the child processes; specifically, remove

self.logger = SummaryWriter(dir_logs)

and add in run_epochs

exp.logger = SummaryWriter(exp.dir_logs)

so that we don't have to fork the lock inside SummaryWriter (in _AsyncWriter, https://github.com/tensorflow/tensorboard/blob/master/tensorboard/summary/writer/event_file_writer.py#L163). In general, each child process should create its own SummaryWriter instead of forking it from the parent process.
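
As a minimal, self-contained sketch of that pattern (a hypothetical worker function that keeps only the writer handling and drops the DDP setup from the code above), each spawned process constructs, uses, and closes its own writer:

import torch.multiprocessing as mp
from tensorboardX import SummaryWriter  # the same pattern works with torch.utils.tensorboard


def worker(rank, log_dir):
    # the writer is created inside the spawned process, so nothing has to be pickled
    writer = SummaryWriter(log_dir)
    for step in range(10):
        writer.add_scalar('toy/metric', rank * 100 + step, step)
    writer.close()


if __name__ == '__main__':
    mp.spawn(worker, args=('temp_dir',), nprocs=2, join=True)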

Also, unrelated to your issue: tensorboardX has long been deprecated and is no longer actively maintained, having been replaced by PyTorch's native TensorBoard support since PyTorch 1.2. To use it, simply replace

from tensorboardX import SummaryWriter

with

from torch.utils.tensorboard import SummaryWriter 
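
For the calls used in this thread (add_scalar, add_scalars, close) the interface is the same, so nothing else should need to change. A quick sanity check, assuming a hypothetical runs/sanity_check log directory:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/sanity_check')  # hypothetical log directory
writer.add_scalars('train/loss', {'train_loss': 0.5}, 0)
writer.add_scalar('error', 0.1, 0)
writer.close()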

Thanks a lot @cryptopic, that works :slight_smile:

Thanks @cryptopic, works fine. I am surprised that there is no need to do any post-processing of the logged data; does TensorBoard automatically join the data from all processes?

does TensorBoard automatically join the data from all processes?

Yes, different processes will write to different log files, and TensorBoard will aggregate all log files during visualization
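
For example, a common layout (a hypothetical make_writer helper, not from this thread) is one subdirectory per rank; TensorBoard pointed at the parent directory picks up every run and overlays them in the same charts:

import os
from torch.utils.tensorboard import SummaryWriter


def make_writer(log_dir, rank):
    # e.g. temp_dir/rank_0, temp_dir/rank_1, ... -- point TensorBoard at log_dir
    return SummaryWriter(os.path.join(log_dir, f'rank_{rank}'))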

Yes, you are right. Just as a note here, when using this we have to bear in mind that, as you say, different processes will write to different log files and TensorBoard will aggregate them all for visualization. The problem is that if you write multiple values for the same variables, you get crazy charts like this:
[screenshot: TensorBoard scalar chart with overlapping curves written by multiple processes]
This can be solved by using the RANK variable, so that only one process writes a log file. For example, I have done something like:

    if args.rank == 0:
        dir = os.path.join(args.output_dir, 'logs')
        logger = SummaryWriter(dir)
        print('writing to {}'.format(dir))
        logger.add_scalar('error', value, 0)
        logger.add_scalar('error', value, 1)
        logger.add_scalar('error', value, 2)
        logger.add_scalar('error', value, 3)
        logger.close()

I am curious: how do you deal with this? Is there a better way of doing it?


Hey Nicolas,

Wondering if you have made any progress on this. I am also facing the same issue here :frowning:
