I’m using PyTorch for gradient-based inverse problem solving. My current architecture uses a custom Trainer class which takes a model, dataset, optimizer, regularization function, etc. and handles all of the training/optimization in its methods. It currently wraps the model in a DataParallel object to utilize multiple available GPUs.
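For context, the setup is roughly like this (heavily simplified; the real loss and regularizer are different):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class Trainer:
    def __init__(self, model, dataset, optimizer, regularizer):
        # DataParallel splits each batch across all visible GPUs in a single process
        self.model = nn.DataParallel(model).cuda()
        self.dataset = dataset
        self.optimizer = optimizer
        self.regularizer = regularizer

    def fit(self, n_epochs, batch_size=8):
        loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        for _ in range(n_epochs):
            for x, y in loader:
                x, y = x.cuda(), y.cuda()
                self.optimizer.zero_grad()
                loss = nn.functional.mse_loss(self.model(x), y) + self.regularizer(self.model)
                loss.backward()
                self.optimizer.step()
```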
Now, I want to switch to DistributedDataParallel to further speed up the process and maybe scale it to multiple nodes. However, I’ve figured out that I can’t do it inside the class method, since DistributedDataParallel requires creating a process group, which (as far as I can tell) can only be done in `__main__`, so do I basically need to write a separate training script to use DistributedDataParallel?
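If I understand the tutorials correctly, the usual pattern would be something like this standalone script (the address/port are placeholders and `build_model()` just stands in for my actual model construction):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def train_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = build_model().to(rank)            # hypothetical model factory
    ddp_model = DDP(model, device_ids=[rank])
    # ... training loop with a DistributedSampler would go here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train_worker, args=(world_size,), nprocs=world_size, join=True)
```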
Have I understood this correctly? (I am not super experienced with torch.multiprocessing.)
Or are there maybe some ways to use a pre-created process group and/or DistributedDataParallel object, so that the training code can stay unified for both DistributedDataParallel and DataParallel? See the sketch below for roughly what I have in mind.
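What I was hoping for is something like the following inside the Trainer, where the process group is created before the Trainer is constructed (e.g. by the launcher) and the class just picks the right wrapper. I’m not sure if this is a sane approach:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:
    def __init__(self, model, dataset, optimizer, regularizer):
        if dist.is_available() and dist.is_initialized():
            # A process group already exists (created by torchrun / mp.spawn
            # before the Trainer was built), so wrap with DDP on this rank's GPU.
            local_rank = dist.get_rank() % torch.cuda.device_count()
            torch.cuda.set_device(local_rank)
            self.model = DDP(model.to(local_rank), device_ids=[local_rank])
        else:
            # Fall back to the current single-process multi-GPU behaviour.
            self.model = nn.DataParallel(model).cuda()
        self.dataset = dataset
        self.optimizer = optimizer
        self.regularizer = regularizer
```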
Thanks in advance for your answers.