[DDP] Train different layers of one model separately

Hi all,

I have a model that I want to train using DDP in a way where I have two sets of workers, W_1 and W_2. A worker from W_1 shall update only the layers L_1, while a worker from W_2 shall update only the layers L_2. The sets of layers L_1 and L_2 are disjoint. How can I achieve this?

Let's look at an example:

I have a model

import torch
import torch.nn as nn

class MyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv2(self.conv1(x))

Now each worker from W_1 shall optimise the weights of MyModel.conv1 and each worker from W_2 shall optimise the weights of MyModel.conv2.

Consider the following pseudo-code for two workers (with rank = 0 and rank = 1):

my_model = MyModel()
opt_w1 = torch.optim.Adam(my_model.conv1.parameters())
opt_w2 = torch.optim.Adam(my_model.conv2.parameters())

# rank, input_batch, target and loss_fn come from the usual DDP setup
if rank == 0:
    # clear stale gradients on all parameters
    opt_w1.zero_grad()
    opt_w2.zero_grad()
    loss = loss_fn(my_model(input_batch), target)
    loss.backward()
    # sync and update only conv1
    opt_w1.step()
elif rank == 1:
    # clear stale gradients on all parameters
    opt_w1.zero_grad()
    opt_w2.zero_grad()
    loss = loss_fn(my_model(input_batch), target)
    loss.backward()
    # sync and update only conv2
    opt_w2.step()

The pseudo-code is not correct, however. backward() computes gradients for all weights, and DDP's gradient synchronization would sum the gradients from all workers and divide by |W_1| + |W_2|. To be correct according to the task, we instead need to sum only the gradients from the workers in W_1, divide by |W_1| and broadcast the result within W_1, and likewise sum only the gradients from the workers in W_2, divide by |W_2| and broadcast within W_2.

How can I do this?
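
To make the intent concrete, here is a rough sketch of the per-group averaging I have in mind, written with plain torch.distributed collectives instead of DDP's automatic synchronization. The group ranks, the world split and the helper name (average_gradients) are made up for illustration, and this ignores how the two halves of the model would stay consistent across the groups:

import torch.distributed as dist

# Assumed split of a 4-process world (illustrative only):
# ranks 0 and 1 form W_1 (update conv1), ranks 2 and 3 form W_2 (update conv2).
group_w1 = dist.new_group(ranks=[0, 1])
group_w2 = dist.new_group(ranks=[2, 3])

def average_gradients(params, group, group_size):
    # Sum the gradients over the given process group and divide by its size.
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
            p.grad /= group_size

# One iteration on a worker in W_1 (a worker in W_2 would do the same
# with conv2, opt_w2 and group_w2):
opt_w1.zero_grad()
opt_w2.zero_grad()
loss = loss_fn(my_model(input_batch), target)
loss.backward()
average_gradients(my_model.conv1.parameters(), group_w1, group_size=2)
opt_w1.step()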

(To give you some motivation: I have a network that works with different types of input, e.g. RGB images and grayscale images, and for each type of input I know which layers I should optimise.)

Your code won’t work, assuming you are using DDP, since you are diverging the models. The model parameters are only shared initially, and DDP depends on gradient synchronization as well as identical parameter updates to keep all replicas equal. In your example you are explicitly updating different parts of the model depending on the rank and will thus create different parameter sets, outputs, losses, and gradients.
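
To illustrate the invariant DDP relies on: in standard usage every rank executes exactly the same forward, backward and optimizer step, and DDP averages the gradients over all ranks during backward(), which is what keeps the replicas identical. A minimal sketch (the process group backend, loss_fn, input_batch and target are placeholders):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
rank = dist.get_rank()

model = MyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
# One optimizer over all parameters; every rank constructs the same one.
optimizer = torch.optim.Adam(ddp_model.parameters())

# Every rank runs this identical sequence.
optimizer.zero_grad()
loss = loss_fn(ddp_model(input_batch), target)
loss.backward()   # DDP all-reduces (averages) gradients across all ranks here
optimizer.step()  # the identical update keeps all replicas equal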

Thank you for the detailed answer!