Why do we have to create the logger inside the process for correct logging in DDP?

Logging prints nothing in the following code:

#!/usr/bin/python
# -*- coding: UTF-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

import os, logging
#logging.basicConfig(level=logging.DEBUG)

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize the process group.
    dist.init_process_group('NCCL', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)

    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Running DDP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank)

    outputs = ddp_model(inputs)

    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func, world_size):
    mp.spawn(
        demo_func,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )


def main():
    run_demo(demo_basic, 4)


if __name__ == "__main__":
    main()

However, when we uncomment the logging.basicConfig(level=logging.DEBUG) line, the logging works. May I know the reason and how to fix this, please?

Hi,

It doesn’t seem to be related to DDP or PyTorch, but to how the logging module is set up. If you remove all the torch code, you would still get the same result.

def main():
    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.info(f'in main.')
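
As far as I can tell, the logger here has no handler attached, and neither does the root logger, so records only reach logging's last-resort handler, which emits WARNING and above; INFO messages are therefore dropped. logging.basicConfig attaches a StreamHandler to the root logger, which is why uncommenting that line makes the output appear. A minimal sketch of both behaviours (no torch involved):

import logging

logger = logging.getLogger('train')
logger.setLevel(logging.DEBUG)

# No handler on 'train' or on the root logger: records fall through to the
# last-resort handler, which only emits WARNING and above.
logger.info('not shown')
logger.warning('shown, via the last-resort handler')

# basicConfig attaches a StreamHandler to the root logger, so propagated
# INFO records are now emitted.
logging.basicConfig(level=logging.DEBUG)
logger.info('now shown')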

Does it block you in any way?

Hi @agolynski, thank you so much for your kind reply. I have adjusted my code and found that the logger works well if it is created inside the DDP process, but fails again if it is passed in as an argument. The following snippet demonstrates this:

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize the process group.
    dist.init_process_group('NCCL', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size, logger=None):
    setup(rank, world_size)

    if rank == 0:
        logger = get_logger() if logger is None else logger
        logger.info(f'info in process')
        logger.error(f'error in process.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank)

    outputs = ddp_model(inputs)

    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func, world_size):
    logger = get_logger()
    # logger = None  # Created from inside.
    mp.spawn(
        demo_func,
        args=(world_size, logger),
        nprocs=world_size,
        join=True
    )


def get_logger():
    logger = logging.getLogger('train')

    # Handlers.
    logger.addHandler(
        logging.StreamHandler()
    )
    logger.setLevel(logging.DEBUG)

    return logger


def example2():
    run_demo(demo_basic, 4)


def main():
    example2()


if __name__ == "__main__":
    main()

If the logger = None line is left commented out (as above), so the logger created in run_demo is passed in as an argument, there is no “info in process” in the output. However, if we comment out logger = get_logger() and uncomment logger = None, so the logger is created inside the spawned process, “info in process” does appear.

It does not block me, but I am quite curious why this happens; I thought DDP was essentially a wrapper around processes.
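
For what it's worth, here is a torch-free sketch that seems to reproduce the behaviour. My assumption is that mp.spawn uses the 'spawn' start method, so the arguments are pickled into the child process, and a logging.Logger is pickled by name only; the child then rebuilds it via logging.getLogger('train'), and the StreamHandler added in the parent never arrives:

import logging
import torch.multiprocessing as mp


def worker(rank, logger):
    # The logger was pickled into this spawned process. If Logger really is
    # pickled by name only, it arrives with no handlers and no level set.
    print(f'rank {rank}: handlers in child = {logger.handlers}')  # expected: []

    logger.info('info in process')    # dropped: effective level falls back to the root's WARNING
    logger.error('error in process')  # still shown, via the last-resort handler

    # Configuring the logger inside the spawned process restores the output,
    # just like calling get_logger() inside demo_basic.
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.DEBUG)
    logger.info('info in process, after configuring the logger in-process')


def main():
    logger = logging.getLogger('train')
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.DEBUG)
    print(f'handlers in parent = {logger.handlers}')

    mp.spawn(worker, args=(logger,), nprocs=2, join=True)


if __name__ == '__main__':
    main()

If that is indeed the mechanism, it is not DDP-specific at all: any logger passed through spawn would lose the handlers attached in the parent, which would explain why creating (or at least configuring) the logger inside each process works.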