Question about the use of ddp_model.no_sync() with torch.utils.checkpoint.checkpoint

Background: ddp_model is a TwoLinearLayerMLP wrapped in torch.nn.parallel.DistributedDataParallel.
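
For context, all of the snippets below assume a setup roughly like the following; the dimensions, the loss function, and the way inputs, labels and micro_batch_size are produced are simplified placeholders for my real training script:

Code0 (setup sketch)

import torch
import torch.nn as nn
import torch.nn.functional as functional
import torch.utils.checkpoint as cp
from torch.nn.parallel import DistributedDataParallel as DDP

class TwoLinearLayerMLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, input):
        a1 = functional.relu(self.linear1(input))
        a2 = self.linear2(a1)
        return input + a2

# torch.distributed.init_process_group(...) is assumed to have been called already
model = TwoLinearLayerMLP(dim=1024)
ddp_model = DDP(model)      # device placement / device_ids omitted for brevity
loss_fn = nn.MSELoss()      # placeholder loss
micro_batch_size = 4        # placeholder
# inputs and labels are lists of micro_batch_size tensors for the local rank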

If I use torch.utils.checkpoint.checkpoint with the whole ddp_model:

Code1

with ddp_model.no_sync():
    for i in range(1, micro_batch_size):
        output = cp.checkpoint(ddp_model, inputs[i])
        loss = loss_fn(output, labels[i])
        loss.backward()
# last micro-batch runs outside no_sync(), so its backward synchronizes the gradients
output = cp.checkpoint(ddp_model, inputs[0])
loss = loss_fn(output, labels[0])
loss.backward()

Code2

outputs = [0] * micro_batch_size
# forward of the first micro-batch runs before entering no_sync()
outputs[0] = cp.checkpoint(ddp_model, inputs[0])
with ddp_model.no_sync():
    for i in range(1, micro_batch_size):
        outputs[i] = cp.checkpoint(ddp_model, inputs[i])
    for i in range(1, micro_batch_size):
        loss = loss_fn(outputs[i], labels[i])
        loss.backward()
# backward of the first micro-batch runs after leaving no_sync()
loss = loss_fn(outputs[0], labels[0])
loss.backward()

Both of the snippets above work properly.

If instead I apply torch.utils.checkpoint.checkpoint to the two linear layers inside the TwoLinearLayerMLP, like this:

Code3

def forward(self, input):
    a1 = functional.relu(cp.checkpoint(self.linear1, input))
    a2 = cp.checkpoint(self.linear2, a1)
    return input + a2
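
so the full module becomes (continuing the Code0 sketch above, only forward changes):

class TwoLinearLayerMLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, input):
        # checkpoint each linear layer separately instead of the whole ddp_model
        a1 = functional.relu(cp.checkpoint(self.linear1, input))
        a2 = cp.checkpoint(self.linear2, a1)
        return input + a2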

With this per-layer checkpointing in place, the following code still works properly:
Code4

with ddp_model.no_sync():
    for i in range(1, micro_batch_size):
        output = ddp_model(inputs[i])
        loss = loss_fn(output, labels[i])
        loss.backward()
output = ddp_model(inputs[0])
loss = loss_fn(output, labels[0])
loss.backward()

But the following code does not work correctly: gradients are never synchronized across the DDP processes:
Code5

outputs = [0] * micro_batch_size
# same structure as Code2, but with the per-layer checkpointing from Code3 inside forward
outputs[0] = ddp_model(inputs[0])
with ddp_model.no_sync():
    for i in range(1, micro_batch_size):
        outputs[i] = ddp_model(inputs[i])
    for i in range(1, micro_batch_size):
        loss = loss_fn(outputs[i], labels[i])
        loss.backward()
loss = loss_fn(outputs[0], labels[0])
loss.backward()

My questions:

  1. I want to make Code5 work properly. How can I solve this problem? Also, are there any points to pay attention to when using ddp_model.no_sync() together with torch.utils.checkpoint.checkpoint?
  2. What’s the difference between Code1 and Code2? I would like to use ddp_model.no_sync() in the style of Code2.