Logging with DistributedDataParallel in PyTorch

I am trying to set up a training workflow with PyTorch DistributedDataParallel (DDP). When I train, I usually pass a logger through to track outputs and record useful information. However, I am having trouble using that logger with DDP. Right now my code is as follows:

import torch
import torch.multiprocessing as mp


class BaseModel:
    def __init__(self, *args, **kwargs):
        ...  # does things

    def fit(self, *args, **kwargs):
        ...  # set up stuff
        mp.spawn(self.distributed_training, nprocs=self.num_gpus, args=(self.params, training_input, self.logger))

    def distributed_training(self, rank, params, training_input, logger):
        # mp.spawn calls this as fn(rank, *args), so the process index arrives first
        ...
        for e in epochs:
            ...  # trains for an epoch
            logger.info(print_line)

I know I am supposed to use the QueueHandler and QueueListener classes from logging.handlers, but I have been scouring the internet and still don't clearly understand how. Any help would be greatly appreciated.
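One way to wire QueueHandler and QueueListener into mp.spawn is sketched below. It is a minimal sketch under a few assumptions: the worker is a module-level function rather than a bound method (so `self` and its logger do not have to be pickled), `params` is a dict with an `epochs` key, and the loss value and the `training.log` path are placeholders. Each child attaches a QueueHandler to its root logger. The parent owns the real handlers and runs a QueueListener thread that drains the queue.

```python
import logging
import logging.handlers
import multiprocessing


def setup_worker_logging(log_queue, level=logging.INFO):
    """Send every record logged in this process to log_queue."""
    root = logging.getLogger()
    # Drop handlers the child may already have (e.g. inherited under fork)
    # so records are not emitted twice.
    root.handlers.clear()
    root.addHandler(logging.handlers.QueueHandler(log_queue))
    root.setLevel(level)


def distributed_training(rank, params, log_queue):
    """Hypothetical worker; mp.spawn calls it as fn(rank, *args)."""
    setup_worker_logging(log_queue)
    logger = logging.getLogger("train")
    for epoch in range(params["epochs"]):
        loss = 1.0 / (epoch + 1)  # placeholder for one real training epoch
        logger.info("rank %d epoch %d loss %.4f", rank, epoch, loss)


def main(num_gpus=2, params=None):
    """Parent side: owns the real handlers and drains the queue."""
    # Imported lazily so the helpers above need only the standard library.
    import torch.multiprocessing as mp

    params = params or {"epochs": 3}
    # A spawn-context queue can be handed to the spawned children via args.
    log_queue = multiprocessing.get_context("spawn").Queue()

    fmt = logging.Formatter("%(asctime)s %(processName)s %(message)s")
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    file_handler = logging.FileHandler("training.log")  # hypothetical path
    file_handler.setFormatter(fmt)

    listener = logging.handlers.QueueListener(log_queue, console, file_handler)
    listener.start()
    try:
        mp.spawn(distributed_training, nprocs=num_gpus, args=(params, log_queue))
    finally:
        listener.stop()  # processes queued records, then joins the thread
```

Call `main()` from under `if __name__ == "__main__":`, because the spawn start method re-imports the main module in every child. All console and file I/O stays in the parent, so the children never need to receive a configured logger.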

What sort of issues do you encounter when running this code?

You could also consider creating a logger inside each spawned process instead of passing the same logger into the spawned processes. This would also let you configure logging per DDP process, for example, writing each process's logs to a different file.
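A per-process logger is still created inside the training code, so it captures everything logged during training; it simply lives in the child instead of being shipped from the parent. A minimal sketch, assuming one file per rank named `train_rank{rank}.log` (a hypothetical naming scheme) and a `params` dict with hypothetical `epochs` and `log_dir` keys:

```python
import logging
import os


def get_rank_logger(rank, log_dir="."):
    """Build a logger inside the spawned process that writes to its own file."""
    logger = logging.getLogger(f"train.rank{rank}")
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep these records away from any root handlers
    if not logger.handlers:  # avoid stacking handlers if called twice
        path = os.path.join(log_dir, f"train_rank{rank}.log")
        handler = logging.FileHandler(path)
        handler.setFormatter(logging.Formatter(f"%(asctime)s [rank {rank}] %(message)s"))
        logger.addHandler(handler)
    return logger


def distributed_training(rank, params):
    """Hypothetical worker body; mp.spawn calls it as fn(rank, *args)."""
    logger = get_rank_logger(rank, params.get("log_dir", "."))
    for epoch in range(params["epochs"]):
        loss = 1.0 / (epoch + 1)  # placeholder for one real training epoch
        logger.info("epoch %d loss %.4f", epoch, loss)
```

Launched with `mp.spawn(distributed_training, nprocs=num_gpus, args=(params,))`, each rank would write its own file, and no logger object crosses a process boundary.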

With the way the code is set up, anything passed to the logger in the spawned process just doesn't go through (i.e., it won't print or save). I'm not sure how a per-process logger would help, as I need to capture things during training. Mainly, I'm just trying to record loss and metrics per epoch, nothing I would consider overly special for a training process. I think the issue is rather that the logger is initialized in the main process, so sending it to another process causes problems (whether that is its own process or the distributed process).
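That suspicion matches how Python pickles loggers. Since Python 3.7, a `Logger` is pickled by name only, so when `mp.spawn` sends it to a child, the child rebuilds it with `logging.getLogger(name)`. Under the spawn start method the child is a fresh interpreter, so handlers added in the parent are likely absent (unless the logging setup runs again at import time), and an INFO record then reaches no handler; the last-resort handler only prints WARNING and above. The snippet below shows what actually gets pickled; the logger name `train` is just an example:

```python
import logging
import pickle

logger = logging.getLogger("train")
logger.addHandler(logging.StreamHandler())  # configuration done in the parent
logger.setLevel(logging.INFO)

# Logger.__reduce__ returns (logging.getLogger, (name,)):
# only the name is serialized, not handlers or levels.
func, args = logger.__reduce__()
print(func is logging.getLogger, args)  # True ('train',)

# In the same process, unpickling looks the name up again and finds the
# configured logger. In a spawned process, the same lookup creates a new
# logger with no handlers.
restored = pickle.loads(pickle.dumps(logger))
print(restored is logger)  # True
```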


Having the same issue here.