Must a model wrapped in DistributedDataParallel implement and call the forward function?

When I use DistributedDataParallel for parallel training, must the model implement and call its forward function? For example:

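import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel

class ToyMpModel(nn.Module):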
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)

    def encoder(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)


def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)  # helper that initializes the process group (not shown)

    # setup mp_model and devices for this process
    dev0 = (rank * 2) % world_size
    dev1 = (rank * 2 + 1) % world_size
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DistributedDataParallel(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    # outputs = ddp_mp_model(torch.randn(2, 10))  # the usual call through DDP's forward
    outputs = ddp_mp_model.module.encoder(torch.randn(2, 10))
    print(outputs)
    labels = torch.randn(2, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

ToyMpModel has two functions, encoder and forward, with identical code. When the model is wrapped in DistributedDataParallel, will outputs = ddp_mp_model.module.encoder(torch.randn(2, 10)) work correctly? That is, will the parameters on the different GPUs still be synchronized (for example, via all-reduce)?

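No, I don't think your usage will work correctly: by explicitly calling the internal .module, you skip the DDP wrapper. DDP's own forward does bookkeeping before and after it calls your module's forward (such as broadcasting buffers and preparing the reducer that all-reduces gradients during the backward pass), so gradients produced by a direct .module call are not guaranteed to be synchronized across processes.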
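A common workaround is to keep encoder as a regular method but dispatch to it from inside forward, so every call still goes through the DDP wrapper. The sketch below is only an illustration under several assumptions: the `mode` argument and `ToyModel` are hypothetical names, the model is simplified to a single CPU device (no dev0/dev1 split), and it uses a single-process gloo group with a `file://` init method so it runs on one machine. In a real multi-GPU job you would keep your own setup() with the nccl backend.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel


class ToyModel(nn.Module):
    """Hypothetical single-device version of ToyMpModel."""

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def encoder(self, x):
        return self.net2(self.relu(self.net1(x)))

    def forward(self, x, mode="encoder"):
        # Dispatch inside forward so that every call passes through the DDP wrapper.
        # `mode` is a hypothetical argument; more branches could be added here.
        if mode == "encoder":
            return self.encoder(x)
        raise ValueError(f"unknown mode: {mode}")


def train_step():
    # Single-process, CPU-only gloo group (assumption made so the sketch is runnable).
    fd, path = tempfile.mkstemp()
    os.close(fd)
    dist.init_process_group("gloo", init_method=f"file://{path}", rank=0, world_size=1)
    try:
        ddp_model = DistributedDataParallel(ToyModel())
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
        loss_fn = nn.MSELoss()

        optimizer.zero_grad()
        # Call the DDP wrapper (kwargs are forwarded to ToyModel.forward),
        # instead of ddp_model.module.encoder(...).
        outputs = ddp_model(torch.randn(2, 10), mode="encoder")
        loss = loss_fn(outputs, torch.randn(2, 5))
        # With more than one process, DDP averages gradients across ranks here.
        loss.backward()
        optimizer.step()
        grads_present = all(p.grad is not None for p in ddp_model.parameters())
    finally:
        dist.destroy_process_group()
    return outputs.shape, loss.item(), grads_present


if __name__ == "__main__":
    print(train_step())
```

The design choice here is that DDP only intercepts `__call__`/forward on the wrapper, so routing alternative entry points through forward is what keeps them inside the synchronized path.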

So must I put this code into the forward function? DDP has its own forward function, and I'm not sure how it relates to the model's forward.
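As far as I understand, the relationship is that calling the DDP object runs DDP's forward, which does its synchronization bookkeeping and then calls the wrapped module's forward; calling a method on .module directly never enters DDP's forward. A minimal sketch that makes this visible, assuming a hypothetical `CountingModel` that counts its forward calls and the same single-process CPU gloo setup used only to keep it runnable:

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class CountingModel(nn.Module):
    """Hypothetical toy model that records how often its forward() runs."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 5)
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        return self.net(x)

    def encoder(self, x):
        # Same computation as forward, but not counted.
        return self.net(x)


def main():
    # Single-process, CPU-only gloo group (assumption so the sketch runs locally).
    fd, path = tempfile.mkstemp()
    os.close(fd)
    dist.init_process_group("gloo", init_method=f"file://{path}", rank=0, world_size=1)
    try:
        model = CountingModel()
        ddp = DistributedDataParallel(model)
        x = torch.randn(2, 10)

        ddp(x)  # DDP's forward runs, then delegates to model.forward
        after_ddp_call = model.forward_calls

        ddp.module.encoder(x)  # plain method call: DDP's forward is never entered
        after_module_call = model.forward_calls
    finally:
        dist.destroy_process_group()
    return after_ddp_call, after_module_call


if __name__ == "__main__":
    print(main())  # (1, 1)
```

The counter only moves when the wrapper is called, which is why code that should be synchronized needs to be reachable from forward.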