When I use DistributedDataParallel for parallel training, must the model implement the forward function and be called through it? For example:
class ToyMpModel(nn.Module):
    """Two-layer toy network whose layers are placed on two separate devices.

    ``net1`` (10 -> 10) is moved to ``dev0`` and ``net2`` (10 -> 5) is moved
    to ``dev1``; a ReLU sits between them.  ``forward`` and ``encoder``
    intentionally perform exactly the same computation.
    """

    def __init__(self, dev0, dev1):
        """Create the layers and place each one on its device.

        Args:
            dev0: device that holds ``net1``.
            dev1: device that holds ``net2``.
        """
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        # Submodules are created in the order net1, relu, net2.
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        """Run ``x`` through net1 + ReLU on ``dev0``, then net2 on ``dev1``."""
        hidden = self.relu(self.net1(x.to(self.dev0)))
        # Move the activation to the second device before the last layer.
        return self.net2(hidden.to(self.dev1))

    def encoder(self, x):
        """Same computation as ``forward``, exposed under a different name."""
        hidden = self.relu(self.net1(x.to(self.dev0)))
        # Move the activation to the second device before the last layer.
        return self.net2(hidden.to(self.dev1))
def demo_model_parallel(rank, world_size):
    """Run one SGD step of the two-device toy model wrapped in DDP.

    Args:
        rank: rank of the current process; also used to derive device ids.
        world_size: value passed to ``setup`` and used as the modulus when
            deriving the two device ids.
    """
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # Pick the two devices used by this process.
    first_device = (rank * 2) % world_size
    second_device = (rank * 2 + 1) % world_size
    ddp_mp_model = DistributedDataParallel(ToyMpModel(first_device, second_device))

    criterion = nn.MSELoss()
    sgd = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
    sgd.zero_grad()

    # The result lives on the second device.
    # Alternative that goes through the wrapper:
    #   outputs = ddp_mp_model(torch.randn(20, 10))
    # NOTE(review): the call below reaches the wrapped module directly and so
    # does not go through the DistributedDataParallel wrapper's own forward;
    # confirm that gradients are still synchronized on this path.
    outputs = ddp_mp_model.module.encoder(torch.randn(2, 10))
    print(outputs)

    targets = torch.randn(2, 5).to(second_device)
    criterion(outputs, targets).backward()
    sgd.step()

    cleanup()
ToyMpModel has two methods, encoder and forward, with identical code. When working with DistributedDataParallel, will outputs = ddp_mp_model.module.encoder(torch.randn(2, 10))
work correctly? That is, will the parameters on the different GPUs still be synchronized, for example via all-reduce?