Distributed training with CPUs

Hi!

I am interested in possibly using Ignite to enable distributed training on CPUs (since I am training a shallow network and have no GPUs available). I tried using ignite.distributed with the gloo backend, but when I set nproc_per_node to more than 1, the program gets stuck and doesn't run (it does run without setting nproc_per_node). The code is practically the same as the CIFAR example.

My question is: is it possible to train PyTorch models on multiple CPUs, be it with or without Ignite?

Thanks in advance!

@lesscomfortable can you share the code and the command you execute to reproduce the issue?

Here is how to run the CIFAR10 script on multiple CPU cores (single node) in a distributed way:

CUDA_VISIBLE_DEVICES="" python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py run --backend=gloo

To make sure it is not just an impression that the program is stuck (a single epoch of CIFAR10 on CPU can take several minutes to execute), you can add a progress bar:

    trainer = Engine(train_step)

    if idist.get_rank() == 0 and (not config["with_clearml"]):
        common.ProgressBar(persist=False).attach(trainer)

In my case it looks like:

Epoch [1/24]: [4/97]   4%|██████▏                              [00:07<03:45]
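As for the "without Ignite" part of the question: yes, plain PyTorch DistributedDataParallel also works on CPU with the gloo backend. Here is a minimal sketch (toy model, random data and hypothetical hyper-parameters, not the CIFAR example):

# plain PyTorch DDP on CPU with the gloo backend -- minimal sketch only
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(10, 1))          # CPU model wrapped in DDP
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()

    for _ in range(5):
        x, y = torch.randn(32, 10), torch.randn(32, 1)   # random toy data
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()          # gradients are all-reduced across processes
        optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)        # 4 CPU processes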

HTH


Yes, here is the code. It is almost identical to the CIFAR example, except that I'm passing the dataloaders and model to the function myself, I'm using a different loss function, and I added a progress bar for training.

def training(local_rank, config):

    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="NN-Training")

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            now = f"stop-on-{config['stop_iteration']}"

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = config['dataloaders']

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)
    
    if idist.get_rank() == 0 and (not config["with_clearml"]):
        common.ProgressBar(desc=f"Trainer (train)", persist=False).attach(trainer, output_transform=lambda x: {'batch loss': x['batch loss']})

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "MeanAbsoluteError": MeanAbsoluteError(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly the same role:
    # - `evaluator` will save the best model based on validation score
    # - `train_evaluator` will only compute metrics on the training set
    evaluator = create_evaluator(model, metrics=metrics, config=config)
    train_evaluator = create_evaluator(model, metrics=metrics, config=config)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store 2 best models by validation mae starting from num_epochs / 2:
    best_model_handler = Checkpoint(
        {"model": model},
        get_save_handler(config),
        filename_prefix="best",
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_mae",
        score_function=Checkpoint.get_default_score_fn("MeanAbsoluteError"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
    )

    # In order to check training resuming we can stop training on a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info(f"Stop training on {trainer.state.iteration} iteration")
            trainer.terminate()

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()


def run(
    model,
    dataloaders,
    seed=543,
    data_path="/tmp/nn",
    output_path="/tmp/output-nn/",
    batch_size=512,
    momentum=0.9,
    weight_decay=1e-4,
    num_workers=12,
    num_epochs=24,
    learning_rate=0.4,
    num_warmup_epochs=4,
    validate_every=1,
    checkpoint_every=1000,
    backend=None,
    resume_from=None,
    log_every_iters=15,
    nproc_per_node=None,
    stop_iteration=None,
    with_clearml=False,
    with_amp=False,
    **spawn_kwargs,
):
    """Main entry to train an model on NN dataset.
    Args:
        seed (int): random state seed to set. Default, 543.
        data_path (str): input dataset path. Default, "/tmp/nn".
        output_path (str): output path. Default, "/tmp/output-nn".
        model (nn.Module): model instance to train.
        dataloaders (tuple): (train_loader, test_loader) dataloaders passed through to the training function.
        batch_size (int): total batch size. Default, 512.
        momentum (float): optimizer's momentum. Default, 0.9.
        weight_decay (float): weight decay. Default, 1e-4.
        num_workers (int): number of workers in the data loader. Default, 12.
        num_epochs (int): number of epochs to train the model. Default, 24.
        learning_rate (float): peak of piecewise linear learning rate scheduler. Default, 0.4.
        num_warmup_epochs (int): number of warm-up epochs before learning rate decay. Default, 4.
        validate_every (int): run model's validation every ``validate_every`` epochs. Default, 1.
        checkpoint_every (int): store training checkpoint every ``checkpoint_every`` iterations. Default, 1000.
        backend (str, optional): backend to use for distributed configuration. Possible values: None, "nccl", "xla-tpu",
            "gloo" etc. Default, None.
        nproc_per_node (int, optional): optional argument to setup number of processes per node. It is useful,
            when main python process is spawning training as child processes.
        resume_from (str, optional): path to checkpoint to use to resume the training from. Default, None.
        log_every_iters (int): argument to log batch loss every ``log_every_iters`` iterations.
            It can be 0 to disable it. Default, 15.
        stop_iteration (int, optional): iteration to stop the training. Can be used to check resume from checkpoint.
        with_clearml (bool): if True, experiment ClearML logger is setup. Default, False.
        with_amp (bool): if True, enables native automatic mixed precision. Default, False.
        **spawn_kwargs: Other kwargs to spawn run in child processes: master_addr, master_port, node_rank, nnodes
    """
    # catch all local parameters
    config = locals()
    config['dataloaders'] = dataloaders
    config.update(config["spawn_kwargs"])
    del config["spawn_kwargs"]

    spawn_kwargs["nproc_per_node"] = nproc_per_node
    if backend == "xla-tpu" and with_amp:
        raise RuntimeError("The value of with_amp should be False if backend is xla")

    with idist.Parallel(backend=backend, **spawn_kwargs) as parallel:
        parallel.run(training, config)


def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_local_rank() > 0:
        # Ensure that only local rank 0 downloads the dataset
        # Thus each node will download a copy of the dataset
        idist.barrier()

    train_dataset, test_dataset = utils.get_train_test_datasets(config["data_path"])

    if idist.get_local_rank() == 0:
        # Ensure that only local rank 0 downloads the dataset
        idist.barrier()

    # Setup data loader also adapted to distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_dataset, batch_size=config["batch_size"], num_workers=config["num_workers"], shuffle=True, drop_last=True,
    )

    test_loader = idist.auto_dataloader(
        test_dataset, batch_size=2 * config["batch_size"], num_workers=config["num_workers"], shuffle=False,
    )
    return train_loader, test_loader


def initialize(config):
    model = config["model"]
    # Adapt model for distributed settings if configured
    model = idist.auto_model(model)

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )
    optimizer = idist.auto_optim(optimizer)
    criterion = nn.L1Loss().to(idist.device())

    le = config["num_iters_per_epoch"]
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    return model, optimizer, criterion, lr_scheduler


def log_metrics(logger, epoch, elapsed, tag, metrics):
    metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()])
    logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}")


def log_basic_info(logger, config):
    logger.info(f"Train {config['model']}")
    logger.info(f"- PyTorch version: {torch.__version__}")
    logger.info(f"- Ignite version: {ignite.__version__}")
    if torch.cuda.is_available():
        # explicitly import cudnn as
        # torch.backends.cudnn can not be pickled with hvd spawning procs
        from torch.backends import cudnn

        logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
        logger.info(f"- CUDA version: {torch.version.cuda}")
        logger.info(f"- CUDNN version: {cudnn.version()}")

    logger.info("\n")
    logger.info("Configuration:")
    for key, value in config.items():
        logger.info(f"\t{key}: {value}")
    logger.info("\n")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info(f"\tbackend: {idist.backend()}")
        logger.info(f"\tworld size: {idist.get_world_size()}")
        logger.info("\n")


def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer


def create_evaluator(model, metrics, config, tag="val"):
    with_amp = config["with_amp"]
    device = idist.device()

    @torch.no_grad()
    def evaluate_step(engine: Engine, batch):
        model.eval()
        x, y = batch[0], batch[1]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        with autocast(enabled=with_amp):
            output = model(x)
        return output, y

    evaluator = Engine(evaluate_step)

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if idist.get_rank() == 0 and (not config["with_clearml"]):
        common.ProgressBar(desc=f"Evaluation ({tag})", persist=False).attach(evaluator)

    return evaluator


def get_save_handler(config):
    if config["with_clearml"]:
        from ignite.contrib.handlers.clearml_logger import ClearMLSaver

        return ClearMLSaver(dirname=config["output_path"])

    return DiskSaver(config["output_path"], require_empty=False)

One notable difference with your command is that I am running this in a Jupyter Notebook in the following manner (no torch.distributed.launch in my code):

res = run(
    model, 
    [train_dl, test_dl], 
    backend='gloo',
    nproc_per_node=4,
       )

The result is that the program hangs after printing the following log:

ignite.distributed.launcher.Parallel INFO: Spawn function '<function training at ...>' in 4 processes

Should I run this as a script with torch.distributed.launch?

Thanks for the prompt answer!

Here is how to launch the code in Jupyter.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import time
import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    time.sleep(local_rank)
    print(idist.get_rank(), ': run with config:', config, '- backend=', idist.backend())
    # do the training ...

backend = 'gloo'
dist_configs = {'nproc_per_node': 4, "start_method": "fork"}
config = {'c': 12345}

with idist.Parallel(backend=backend, **dist_configs) as parallel:
    parallel.run(training, config, a=1, b=2)

You have to use start_method="fork".

If you would like to run it as a script file and spawn the processes from your main.py script as you do, then you can use the default start_method. Also, it could be helpful to set persistent_workers=True for the DataLoader to speed up data fetching every epoch…
If you would like to use a script file and launch the processes with torch.distributed.launch, you can simply reuse the command from my previous message (and there is no need to set persistent_workers=True).
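For completeness, a minimal sketch of the first option (spawning from your own script with the default start_method); the file name and config values below are made up:

# main.py (hypothetical) -- spawn worker processes from your own script
import ignite.distributed as idist

def training(local_rank, config):
    # the distributed group is initialized here; do the actual training
    print(idist.get_rank(), "/", idist.get_world_size(), "- backend:", idist.backend())

if __name__ == "__main__":
    # the default start_method is fine when launching from a script file
    with idist.Parallel(backend="gloo", nproc_per_node=4) as parallel:
        parallel.run(training, {"num_epochs": 24})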

Thanks for the tips!

I included all the code necessary to run it in Jupyter, and it now prints out the training info 4 times, each time with a different rank (1 through 4). However, it still hangs, and htop shows that the CPU cores are unused :grimacing:. Any other clues on what might be going on?

Just in case, these are my parameters:

2021-08-09 20:40:58,801 NN-Training INFO: 	seed: 543
2021-08-09 20:40:58,802 NN-Training INFO: 	output_path: test
2021-08-09 20:40:58,803 NN-Training INFO: 	momentum: 0.9
2021-08-09 20:40:58,804 NN-Training INFO: 	weight_decay: 0.0001
2021-08-09 20:40:58,805 NN-Training INFO: 	num_workers: 12
2021-08-09 20:40:58,806 NN-Training INFO: 	num_epochs: 5
2021-08-09 20:40:58,807 NN-Training INFO: 	learning_rate: 0.05
2021-08-09 20:40:58,809 NN-Training INFO: 	num_warmup_epochs: 1
2021-08-09 20:40:58,810 NN-Training INFO: 	validate_every: 1
2021-08-09 20:40:58,811 NN-Training INFO: 	checkpoint_every: 3
2021-08-09 20:40:58,813 NN-Training INFO: 	backend: gloo
2021-08-09 20:40:58,818 NN-Training INFO: 	resume_from: None
2021-08-09 20:40:58,819 NN-Training INFO: 	log_every_iters: 15
2021-08-09 20:40:58,820 NN-Training INFO: 	nproc_per_node: 4
2021-08-09 20:40:58,821 NN-Training INFO: 	stop_iteration: None
2021-08-09 20:40:58,822 NN-Training INFO: 	with_clearml: False
2021-08-09 20:40:58,823 NN-Training INFO: 	with_amp: False
2021-08-09 20:40:58,824 NN-Training INFO: 	start_method: fork

@lesscomfortable this is strange. OK, can you please provide your pytorch and ignite versions and confirm that you cannot run this code snippet: Distributed training with CPU's - #4 by vfdev-5 ?

I created a colab to show that the above code snippet runs correctly on CPUs: Google Colaboratory

If you could modify it to include your code, execute it there (with some random data), and reproduce the issue, that would be nice.

This is something I probably should have included in my params: I'm using pytorch 1.6.0. However, your example runs successfully both in Colab and in my environment when using pytorch 1.6.0, so the problem must be in some other part of my code. These are my versions:

2021-08-10 09:36:14,511 NN-Training INFO: - PyTorch version: 1.6.0
2021-08-10 09:36:14,511 NN-Training INFO: - Ignite version: 0.4.6

I'll try with my code in the Google Colab and let you know if I can reproduce the problem.

Thanks!

Sounds good! Let me know if you can create a reproducible colab with the issue.

By the way, inspecting the current code, I wonder how you create the data loaders. The way it is done, it is unlikely that the training would do proper DDP: the data will be the same for all processes, which is probably undesired.
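A quick way to check this from inside the training function (a sketch, assuming each batch is an (x, y) pair of tensors) is to print what each rank sees:

import ignite.distributed as idist

def check_sharding(loader):
    # print the targets of the first batch seen by this process;
    # if every rank prints the same values, the data is not being sharded
    x, y = next(iter(loader))
    print(f"rank {idist.get_rank()}: first targets = {y[:5].tolist()}")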

After reading a little bit, I am creating them like this (I changed this from the original code I sent previously in this thread):

train_sampler, valid_sampler = DistributedSampler(train_ds), DistributedSampler(valid_ds)

train_loader = idist.auto_dataloader(
    train_ds, batch_size=bs, num_workers=n_workers, sampler=train_sampler, drop_last=True
)

test_loader = idist.auto_dataloader(
    valid_ds, batch_size=bs, num_workers=n_workers, sampler=valid_sampler
)

If you use idist.auto_dataloader, you do not need to specify a distributed sampler yourself if you are using the default one; auto_dataloader adds it for you.

However, the dataloader definition should be inside the training function, so that the distributed group is already initialized and the rank and world size are defined.
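In other words, something along these lines (a minimal sketch with dummy data; the config keys are assumptions):

import torch
import ignite.distributed as idist
from torch.utils.data import TensorDataset

def training(local_rank, config):
    # at this point the distributed group is initialized, so auto_dataloader
    # knows the rank/world size and can attach a DistributedSampler for us
    train_ds = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))  # dummy dataset
    train_loader = idist.auto_dataloader(
        train_ds,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=True,     # replaced by a DistributedSampler under the hood
        drop_last=True,
    )
    for x, y in train_loader:
        pass  # each process now iterates over its own shard of the data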

Ok, I finally made it work! I was trying to reproduce the code in Colab taking your suggestions into account, and it now works correctly (I verified that all CPU cores are being fully used). I don't know exactly what did the trick, but here is the working code:


def training(local_rank, config, model, datasets):
        
    rank = idist.get_rank()

    print(f'Running with rank {rank} and local rank {local_rank}')

    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="NN-Training")

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            now = f"stop-on-{config['stop_iteration']}"

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    train_ds, test_ds = datasets
    
    # Setup data loader also adapted to distributed config: nccl, gloo, xla-tpu
    train_loader = idist.auto_dataloader(
        train_ds, 
        collate_fn=collate,
        batch_size=bs, 
        num_workers=n_workers, 
        drop_last=True
    )

    test_loader = idist.auto_dataloader(
        test_ds,
        collate_fn=collate,
        batch_size=bs,
        num_workers=n_workers,
    )

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(model, config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "mae": MeanAbsoluteError()
    }
    
    # Setup running average metric
    acc_metric = RunningAverage(MeanAbsoluteError(output_transform=lambda x: [x['y'], x['y_pred']]), alpha=0.98)
    acc_metric.attach(trainer, 'running_avg_mae')
    
    @trainer.on(Events.EPOCH_COMPLETED(every=1) | Events.COMPLETED)
    def log_running_avg_metrics(engine):
        epoch = trainer.state.epoch
        avg_mae = engine.state.metrics['running_avg_mae']
        logger.info(f"\nEpoch {epoch} - Train metrics:\n running average mae: {avg_mae}")
    
    # - `evaluator` will save the best model based on validation score
    evaluator = create_evaluator(model, metrics=metrics, config=config)

    @trainer.on(Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED)
    def run_validation(engine):
        epoch = trainer.state.epoch
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation test metrics
        evaluators = {"test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store 2 best models by validation accuracy starting from num_epochs / 2:
    best_model_handler = Checkpoint(
        {"model": model},
        get_save_handler(config),
        filename_prefix="best",
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_mae",
        score_function=Checkpoint.get_default_score_fn("mae"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
    )

    # In order to check training resuming we can stop training on a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info(f"Stop training on {trainer.state.iteration} iteration")
            trainer.terminate()

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()


def run(
    model,
    datasets,
    seed=543,
    output_path="results/",
    momentum=0.9,
    weight_decay=1e-4,
    num_workers=12,
    num_epochs=5,
    learning_rate=0.05,
    num_warmup_epochs=4,
    validate_every=1,
    checkpoint_every=3,
    backend='gloo',
    resume_from=None,
    log_every_iters=15,
    nproc_per_node=None,
    stop_iteration=None,
    with_clearml=False,
    with_amp=False,
    **spawn_kwargs,
):
    # catch all local parameters
    config = locals()
    config.update(config["spawn_kwargs"])
    del config["spawn_kwargs"]

    spawn_kwargs["nproc_per_node"] = nproc_per_node
    if backend == "xla-tpu" and with_amp:
        raise RuntimeError("The value of with_amp should be False if backend is xla")

    with idist.Parallel(backend=backend, **spawn_kwargs) as parallel:
        parallel.run(training, config, model, datasets)

def initialize(model, config):
    # Adapt model for distributed settings if configured
    model = idist.auto_model(model)

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )
    optimizer = idist.auto_optim(optimizer)
    criterion = nn.L1Loss().to(idist.device())

    le = config["num_iters_per_epoch"]
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    return model, optimizer, criterion, lr_scheduler


def log_metrics(logger, epoch, elapsed, tag, metrics):
    metrics_output = "\n".join([f"{k}: {v}" for k, v in metrics.items()])
    logger.info(f"\nEpoch {epoch} - {tag} metrics:\n {metrics_output}")

def log_basic_info(logger, config):
    logger.info(f"Train {config['model']} on CIFAR10")
    logger.info(f"- PyTorch version: {torch.__version__}")
    logger.info(f"- Ignite version: {ignite.__version__}")
    if torch.cuda.is_available():
        # explicitly import cudnn as
        # torch.backends.cudnn can not be pickled with hvd spawning procs
        from torch.backends import cudnn

        logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
        logger.info(f"- CUDA version: {torch.version.cuda}")
        logger.info(f"- CUDNN version: {cudnn.version()}")

    logger.info("\n")
    logger.info("Configuration:")
    for key, value in config.items():
        logger.info(f"\t{key}: {value}")
    logger.info("\n")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info(f"\tbackend: {idist.backend()}")
        logger.info(f"\tworld size: {idist.get_world_size()}")
        logger.info("\n")

def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger, tag='train'):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]
        
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
            "y": y,
            "y_pred": y_pred
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )
    
    if idist.get_rank() == 0 and (not config["with_clearml"]):
        common.ProgressBar(desc=f"Trainer ({tag})", persist=False).attach(trainer, output_transform=lambda x: {'batch loss': x['batch loss']})

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer

def create_evaluator(model, metrics, config, tag="val"):
    with_amp = config["with_amp"]
    device = idist.device()

    @torch.no_grad()
    def evaluate_step(engine: Engine, batch):
        model.eval()
        x, y = batch[0], batch[1]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        with autocast(enabled=with_amp):
            output = model(x)
        return output, y

    evaluator = Engine(evaluate_step)

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if idist.get_rank() == 0 and (not config["with_clearml"]):
        common.ProgressBar(desc=f"Evaluation ({tag})", persist=False).attach(evaluator)

    return evaluator

def get_save_handler(config):
    if config["with_clearml"]:
        from ignite.contrib.handlers.clearml_logger import ClearMLSaver

        return ClearMLSaver(dirname=config["output_path"])

    return DiskSaver(config["output_path"], require_empty=False)

The only thing I could not work out was using persistent_workers=True with idist.auto_dataloader:

TypeError: __init__() got an unexpected keyword argument 'persistent_workers')

Is there any way of using persistent workers with auto_dataloader?

Glad that you could make it work.

Oh, I just realized that persistent_workers was not available in pytorch 1.6.0; it appears in the docs since 1.7.0. Since idist.auto_dataloader just passes the arguments through and the pytorch 1.6.0 DataLoader does not accept them, that's why you get the error.
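If you want to keep a single code path across PyTorch versions, one option (a sketch; the version check is naive and bs/n_workers/train_ds are the variables from your snippets) is to only pass the argument when the installed DataLoader supports it:

import torch
import ignite.distributed as idist

extra_kwargs = {}
major, minor = (int(v) for v in torch.__version__.split(".")[:2])
# persistent_workers exists in torch.utils.data.DataLoader from 1.7.0 on,
# and only makes sense with num_workers > 0
if (major, minor) >= (1, 7) and n_workers > 0:
    extra_kwargs["persistent_workers"] = True

train_loader = idist.auto_dataloader(
    train_ds,
    batch_size=bs,
    num_workers=n_workers,
    drop_last=True,
    **extra_kwargs,
)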


Right, makes sense :slight_smile:

Thank you very much for the help, and let me say Ignite is awesome :fire:


One other thing I noticed is that either num_workers and nproc_per_node must be the same, or num_workers must be set to 0, for the code to run. Do you know the reason behind this?

Case in point: the code runs with num_workers=8 and nproc_per_node=8, and with num_workers=0 and nproc_per_node=8, but not with num_workers=10/9/7/6 (any of these) and nproc_per_node=8 (DataLoader worker (pid(s) 53785) exited unexpectedly.).

I’m using persistent_workers=True when running with num_workers>0 and persistent_workers=False with num_workers=0.

Interesting. I’ll try on my infrastructure to see. On the one hand, auto_dataloader scales num_workers per process by the number of local processes: auto_dataloader — PyTorch-Ignite v0.4.6 Documentation

        if "num_workers" in kwargs and kwargs["num_workers"] >= nproc:
            kwargs["num_workers"] = (kwargs["num_workers"] + nproc - 1) // nproc

You can see the number of workers the logger is reporting. If num_workers >= 8, it is divided by nproc_per_node; otherwise, the given number is used as-is by each process. Probably, if you set num_workers=1, it should run as well.
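To make the effect of that formula concrete, here is a quick check in plain Python (using nproc=8 as in your runs):

nproc = 8
for num_workers in (0, 1, 6, 7, 8, 9, 10):
    if num_workers >= nproc:
        effective = (num_workers + nproc - 1) // nproc   # scaled per process
    else:
        effective = num_workers                          # used as-is by every process
    print(f"num_workers={num_workers:>2} -> {effective} worker(s) per process, "
          f"{effective * nproc} in total across {nproc} processes")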

On the other hand, each of the 8 processes (nproc_per_node) on the machine will create a dataloader, which will also spawn workers. Maybe the system cannot spawn that many?

Sometimes it can help to increase the shared memory size if running in a docker container, using the flag --shm-size 16G.

EDIT: I updated the Colab to run resnet18 on cifar10 with 8 procs and 10 workers; it seems to be running.

Interesting, I didn't know about the scaling of num_workers.

I just peeked into your Colab and it is clear that it works that way. Probably what was happening is that, as long as the scaling is on (when num_workers is smaller than 8), the system works because num_workers is low, but when the scaling is off, past a certain threshold (in this case 8), it can't handle that many workers. This threshold (max num_workers) probably varies depending on the instance.

Will be experimenting to find the sweet spot for my specific case. Thanks!