Passing model, dataset, optimizer etc. into DDP

Hi, currently I have the following train function:

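from torch.nn import DataParallel
from torch.utils.data import DataLoader

def train(model, dataset, sampler, criterion, optimizer, scheduler, cfg):
    model = DataParallel(model.cuda())

    # DataLoader takes batch_size, not bs
    loader = DataLoader(dataset, batch_size=cfg.BS, num_workers=4, sampler=sampler)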

    for epoch_idx in range(cfg.EPOCHS):
        for batch, targets in loader:
            preds = model(batch)
            loss = criterion(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

How can I convert this code to use DistributedDataParallel?

I looked in the tutorials, but they initialise model, dataset, etc. inside the train function. I can’t do that since I want the signature to remain the same and have the flexibility of defining model, dataset, etc. outside the train function.

Can I just pass all that using arguments in mp.spawn?

Hey @Rizhiy

You might not be able to pass the optimizer like that, because every subprocess needs its own dedicated optimizer. I'm also not sure how the dataset, criterion, and scheduler in your code behave in multiprocessing use cases. If it is just the model, the following might work.

def train(rank, model):
    # Assumes torch.distributed.init_process_group(...) has already been called in this process
    model = DistributedDataParallel(model.to(rank), device_ids=[rank], output_device=rank)
    ...

def main():
    model.share_memory() 
    mp.spawn(
        train, 
        args=(model, dataset, ...),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

rank is the subprocess index, which mp.spawn provides as the first argument to the target function.

If the goal is to keep the train() signature intact, would it be possible to create a wrapper function around train()? The wrapper would configure everything, call train(), and serve as the spawn target.
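
To make the wrapper idea concrete, here is a minimal sketch, not a definitive implementation. The names `build_components`, `worker`, and `launch` are hypothetical, as are the toy model and data, the `gloo` backend on CPU, and the address/port defaults. The point is that the process group, the DDP wrapping, and the optimizer are all created inside each subprocess, and train() keeps its seven-argument signature.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_components(rank, world_size):
    """Hypothetical factory: builds the per-process objects inside the worker."""
    model = torch.nn.Linear(4, 2)  # toy model, stands in for the real one
    dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 2))
    # Each rank gets its own shard of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    criterion = torch.nn.MSELoss()
    return model, dataset, sampler, criterion


def worker(rank, world_size, train_fn, cfg):
    """Spawn target: mp.spawn passes the process index as the first argument."""
    # Assumed rendezvous settings for a single-machine run.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        model, dataset, sampler, criterion = build_components(rank, world_size)
        model = DistributedDataParallel(model)  # CPU module, so no device_ids
        # The optimizer and scheduler are created per process, after wrapping.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        train_fn(model, dataset, sampler, criterion, optimizer, scheduler, cfg)
    finally:
        dist.destroy_process_group()


def launch(world_size, train_fn, cfg):
    """Start one process per rank; train_fn and cfg must be picklable."""
    mp.spawn(worker, args=(world_size, train_fn, cfg), nprocs=world_size, join=True)
```

Calling `launch(2, train, cfg)` under an `if __name__ == "__main__":` guard would start two workers. With this layout train() receives a model that is already wrapped, so its own DataParallel line would have to go, and train_fn has to be a module-level function so that it can be pickled.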

Hi @mrshenli, the reason is that this function is part of an internal framework. Other programmers should be able to set up the arguments however they require, and train() is only responsible for the loop.

I guess my main question is: what can be passed as arguments in mp.spawn()?

IIUC, it can accept shared-memory tensors and anything that is picklable with Python multiprocessing.
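
One way to check an argument before handing it to mp.spawn is to try pickling it yourself. The helper below is a hypothetical sketch using only the standard library; it does not cover tensors, which torch.multiprocessing handles through its own shared-memory path. It only tells you whether plain `pickle` can serialize an object, which is what the default "spawn" start method requires for ordinary Python arguments.

```python
import pickle
import threading


def is_picklable(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A plain config dict can be pickled and passed through args.
print(is_picklable({"BS": 32, "EPOCHS": 10}))  # True

# Lambdas have no importable name, so they cannot be pickled.
print(is_picklable(lambda batch: batch))  # False

# Objects that hold OS-level resources, such as locks, cannot be pickled.
print(is_picklable(threading.Lock()))  # False
```

So a config object, a dataset backed by plain tensors or file paths, or a module-level function should usually go through, while anything holding lambdas, locks, or open handles would need to be built inside the subprocess instead.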