Is it possible to use DistributedDataParallel inside of a class?

Hi there!

I’m using PyTorch for gradient-based inverse problem solving. My current architecture uses a custom Trainer class that takes a model, dataset, optimizer, regularization function, etc., and handles all the training/optimization in its methods. It currently wraps the model in a DataParallel object to utilize multiple available GPUs.

Now, I want to switch to DistributedDataParallel to further speed up the process and maybe scale it to multiple nodes. However, I’ve found that I can’t do this in a class method, since DistributedDataParallel requires process group creation, which (as I understand it) can only be done in __main__ — so it seems I basically need to write a separate training script to use DistributedDataParallel?

Have I understood this correctly? (I am not very experienced with torch.multiprocessing.)
Is there maybe some way to use a pre-created process group and/or DistributedDataParallel objects, so that the training process is unified for both DistributedDataParallel and DataParallel?

Thanks for your answer.

In general we recommend using DistributedDataParallel over DataParallel. Yes, the setup for DistributedDataParallel is indeed different from DataParallel, since it requires setting up a process group and spawning multiple processes. My recommendation would be to keep just one version of your training script and use only DistributedDataParallel.
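To make this concrete: the process group does not have to be created in __main__ itself — it has to be created once per worker process, and a class method can do that just fine. Below is a minimal sketch (class and method names like `Trainer._worker` are illustrative, not from any specific codebase) that spawns one worker per device from inside a `fit()` method and wraps the model in DDP there. It uses the `gloo` backend so it also runs on CPU; on GPUs you would use `nccl` and pass `device_ids=[rank]`.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:
    """Illustrative Trainer that owns the DDP setup internally."""

    def __init__(self, model, world_size=2):
        self.model = model
        self.world_size = world_size

    def _worker(self, rank):
        # Process-group creation can live in any method; it only has to
        # run once inside each spawned worker process.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=rank, world_size=self.world_size)
        ddp_model = DDP(self.model)  # on GPU: DDP(model, device_ids=[rank])
        # ... the usual training loop on ddp_model goes here ...
        dist.destroy_process_group()

    def fit(self):
        # Spawns world_size worker processes, each running _worker(rank).
        mp.spawn(self._worker, nprocs=self.world_size, join=True)
```

The only thing that genuinely must sit under an `if __name__ == "__main__":` guard is the call to `Trainer(...).fit()` in your launch script, because the spawned children re-import the main module.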

Hi @Konsthr,

There is no restriction on where DDP is instantiated. Using DDP inside a class or outside of one is not a problem.
You can refer to several examples, including mine. The following functions are related:
- `setup`
- `parallelize` (under `model/`)

p.s. My implementation allows both DP and DDP.
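As a rough illustration of how one codebase can allow both, a single wrapping helper can switch between DP and DDP at runtime. This is a minimal sketch (the function name `parallelize` and the `mode` parameter are assumptions for illustration, not the linked implementation); the `"ddp"` branch expects the process group to already be initialized in the current process.

```python
import torch
import torch.distributed as dist
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP

def parallelize(model, mode="none", device_ids=None):
    """Wrap `model` for the requested parallelism mode.

    mode: "dp" for DataParallel, "ddp" for DistributedDataParallel
    (requires an initialized process group), anything else returns
    the model unwrapped for plain single-device training.
    """
    if mode == "dp":
        return DataParallel(model, device_ids=device_ids)
    if mode == "ddp":
        assert dist.is_initialized(), "call init_process_group() first"
        return DDP(model, device_ids=device_ids)
    return model
```

With this shape, the rest of the Trainer only ever sees "a module with a forward pass", so the training loop stays identical across all three modes.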