I’m building a Trainer class that supports PyTorch DDP (multi-GPU training). The structure looks like this:
class Trainer:
    def __init__(self, ...):
        # initialize everything once here, in the parent process
        self.train_loader = ...
        self.model = ...
        self.optimizer = ...

    def train_loop(self, rank, world_size):
        # move the model and each batch to this process's GPU
        self.model.to(rank)
        for data in self.train_loader:
            data = data.to(rank)

    def run(self):
        # spawn one training process per GPU
        world_size = ...
        torch.multiprocessing.spawn(self.train_loop,
                                    args=(world_size,),
                                    nprocs=world_size, join=True)
The above code runs, but the absence of errors doesn’t mean it’s doing what it’s supposed to do, especially with DDP.
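For comparison, below is my understanding of the standard per-process DDP setup (join a process group, wrap the model in DistributedDataParallel, shard the data with a DistributedSampler) that I am trying to fold into the class. The NCCL backend, the MASTER_ADDR/MASTER_PORT values, and the toy model/dataset are placeholders I chose for this sketch, not part of my actual Trainer:

```python
# Sketch of the per-process DDP setup I am trying to reproduce inside the
# class. The NCCL backend, MASTER_ADDR/MASTER_PORT values, and the toy
# model/dataset are assumptions for illustration only.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def train_loop(rank, world_size, model, dataset):
    # Every spawned process joins the same process group before DDP wrapping.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # This process's copy of the model goes to its own GPU; DDP then keeps
    # gradients synchronized across ranks during backward().
    model.to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # A DistributedSampler gives each rank a disjoint shard of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    for x, y in loader:
        x, y = x.to(rank), y.to(rank)
        loss = torch.nn.functional.mse_loss(ddp_model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # one process per GPU
    model = torch.nn.Linear(10, 1)
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    torch.multiprocessing.spawn(train_loop, args=(world_size, model, dataset),
                                nprocs=world_size, join=True)
```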
Question:
- Is the approach correct? My understanding is that, unlike threading, multiprocessing spawns independent processes, and each of them gets its own copy of the trainer object. Is that correct? (See the small probe after this question for how I tried to check this.)
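To make the question concrete, here is the kind of minimal probe I would run to check that assumption; the dummy `step` attribute and the prints are only there for illustration:

```python
# Minimal probe of my mental model: each spawned process should get its own
# copy of the Trainer instance (same initial state, different pid), so
# mutations inside train_loop should not be visible in the parent process.
import os
import torch.multiprocessing as mp


class Trainer:
    def __init__(self):
        self.step = 0  # dummy state, just to see whether it is shared

    def train_loop(self, rank, world_size):
        self.step += 1  # mutates only this process's copy
        print(f"rank={rank} pid={os.getpid()} step={self.step}")

    def run(self, world_size=2):
        mp.spawn(self.train_loop, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    t = Trainer()
    t.run()
    # If the processes really are independent copies, this still prints 0.
    print(f"parent pid={os.getpid()} step={t.step}")
```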