I am trying to train MaskRCNN (the class from torchvision) across 3 machines (master, worker-1 and worker-2), each with multiple GPUs, using distributed RPC with model parallelism. So far I have managed to split the whole architecture across 2 machines manually, relying heavily on the following resources:
Finetuning MaskRCNN in general: TorchVision Object Detection Finetuning Tutorial — PyTorch Tutorials 1.7.1 documentation
Distributed RPC: Distributed Pipeline Parallelism Using RPC — PyTorch Tutorials 1.7.1 documentation
Torchvision github
Now the only thing left is to write the training loop, and this is where I am stuck. I have studied the training loop in https://github.com/pytorch/vision/tree/master/references/detection/engine.py (the standard loop for training Mask R-CNN, linked above) and need to turn it into a distributed training loop. The distributed RPC tutorial mentioned above also shows a training loop, but for a classification task. How do I combine the two to train Mask R-CNN in a distributed way (with model parallelism)? A rough sketch of what I imagine is below.
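To make the question concrete, here is a minimal sketch of what I think the distributed loop should look like, assuming my sharded model exposes a `parameter_rrefs()` helper like the model in the pipeline-parallelism tutorial (the `model`, `data_loader` and `num_epochs` arguments are placeholders; engine.py's lr warm-up, logging and loss reduction across processes are left out):

```python
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
from torch.optim import SGD


def train_distributed(model, data_loader, num_epochs):
    # `model` is assumed to be my RPC-sharded Mask R-CNN wrapper: its pieces
    # live on worker-1/worker-2 behind RRefs, and parameter_rrefs() (as in the
    # pipeline-parallelism tutorial) returns RRefs to every remote parameter.
    opt = DistributedOptimizer(
        SGD,
        model.parameter_rrefs(),
        lr=0.005, momentum=0.9, weight_decay=0.0005,
    )

    for epoch in range(num_epochs):
        for images, targets in data_loader:
            with dist_autograd.context() as context_id:
                # the forward pass crosses machines via RPC; in training mode
                # Mask R-CNN returns a dict of losses
                loss_dict = model(images, targets)
                loss = sum(loss_dict.values())
                # distributed autograd + DistributedOptimizer replace the
                # loss.backward() / optimizer.step() calls in engine.py
                dist_autograd.backward(context_id, [loss])
                opt.step(context_id)
```

Is this the right overall shape, or does the detection loss dict need to be handled differently under distributed autograd?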
EDIT:
Question-2:
The second shard of the model, which resides on the second machine, has 2 modules. Each of them is allocated 1 GPU on that machine (module-1 on GPU-1, module-2 on GPU-2). The other 2 GPUs on the same machine are idle. Splitting each of these 2 modules across 4 GPUs would mean overriding their forward methods, which are very complicated (I am talking about the RPN and RoI Heads modules). Instead, can I do something like this: put module-1 on GPU-1 & GPU-2 and module-2 on GPU-3 & GPU-4, and when an input batch arrives, split it equally into 2 parts and process them as part1 → module-1 (GPU-1) → module-2 (GPU-3) and part2 → module-1 (GPU-2) → module-2 (GPU-4)? Of course, these 2 workflows would run concurrently, their losses would be averaged at the end, and the average loss returned to the master node. I have already referred to the following articles:
https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html
https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
Is it possible to implement the case explained above in PyTorch, for example along the lines of the sketch below?
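Here is a rough sketch of what I have in mind for the second shard. `TwoStagePipeline`, the deepcopy replication and the plain-tensor inputs are my own simplifications (the real RPN/RoI Heads take feature dicts, proposals and targets, not a single tensor), so please treat it only as an illustration of the batch-split idea:

```python
import copy

import torch.nn as nn
from torch.nn.parallel import parallel_apply


class TwoStagePipeline(nn.Module):
    """One copy of module-1 and module-2 chained across two GPUs."""

    def __init__(self, module1, module2, dev1, dev2):
        super().__init__()
        self.m1 = copy.deepcopy(module1).to(dev1)
        self.m2 = copy.deepcopy(module2).to(dev2)
        self.dev1, self.dev2 = dev1, dev2

    def forward(self, x):
        out = self.m1(x.to(self.dev1))
        return self.m2(out.to(self.dev2))


def build_pipelines(module1, module2):
    # module-1 replicated on GPU-1/GPU-2, module-2 replicated on GPU-3/GPU-4
    pipe_a = TwoStagePipeline(module1, module2, "cuda:0", "cuda:2")  # part1 path
    pipe_b = TwoStagePipeline(module1, module2, "cuda:1", "cuda:3")  # part2 path
    return pipe_a, pipe_b


def forward_second_shard(pipe_a, pipe_b, batch):
    # split the incoming batch into two equal halves
    part1, part2 = batch.chunk(2, dim=0)
    # parallel_apply runs the two pipelines in separate threads, so the two
    # GPU streams can overlap
    out_a, out_b = parallel_apply(
        [pipe_a, pipe_b], [part1, part2],
        devices=["cuda:0", "cuda:1"],  # first GPU of each pipeline
    )
    # the real modules return loss dicts in training mode; a mean over the
    # outputs stands in for the averaged loss sent back to the master node
    loss = (out_a.mean() + out_b.mean()) / 2
    return loss
```

I realise the two deep copies would have independent parameters, so their gradients would need to be synchronised (or the replicas rebuilt every iteration, the way nn.DataParallel does); that is part of what I am unsure about.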