Question about the use of ddp_model.no_sync() with torch.utils.checkpoint.checkpoint
Background: ddp_model is a TwoLinearLayerMLP wrapped in DistributedDataParallel (DDP).
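For reference, this is roughly the setup behind the snippets below (a minimal sketch: the hidden size, the per-micro-batch batch size, the MSE loss, and the random data are placeholders for my real code, and I assume the process group has already been initialized):

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as functional
import torch.utils.checkpoint as cp
from torch.nn.parallel import DistributedDataParallel as DDP

class TwoLinearLayerMLP(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, input):
        a1 = functional.relu(self.linear1(input))
        a2 = self.linear2(a1)
        return input + a2

# dist.init_process_group(...) is called elsewhere before this point
model = TwoLinearLayerMLP().cuda()
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])

micro_batch_size = 4
loss_fn = nn.MSELoss()
# requires_grad=True so that checkpointing the whole module returns an output that requires grad
inputs = [torch.randn(8, 128, device="cuda", requires_grad=True) for _ in range(micro_batch_size)]
labels = [torch.randn(8, 128, device="cuda") for _ in range(micro_batch_size)]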
If I use torch.utils.checkpoint.checkpoint with the whole ddp_model:
Code1
with ddp_model.no_sync():
    # accumulate gradients locally; DDP does not all-reduce inside no_sync()
    for i in range(1, micro_batch_size):
        output = cp.checkpoint(ddp_model, inputs[i])
        loss = loss_fn(output, labels[i])
        loss.backward()
# the last micro-batch runs outside no_sync(), so this backward triggers the all-reduce
output = cp.checkpoint(ddp_model, inputs[0])
loss = loss_fn(output, labels[0])
loss.backward()
Code2
outputs = [0] * micro_batch_size
# the first micro-batch is forwarded with gradient sync enabled
outputs[0] = cp.checkpoint(ddp_model, inputs[0])
with ddp_model.no_sync():
    # the remaining forwards run with gradient sync disabled
    for i in range(1, micro_batch_size):
        outputs[i] = cp.checkpoint(ddp_model, inputs[i])
for i in range(1, micro_batch_size):
    loss = loss_fn(outputs[i], labels[i])
    loss.backward()
# this backward corresponds to the forward done outside no_sync() and triggers the all-reduce
loss = loss_fn(outputs[0], labels[0])
loss.backward()
Both of the snippets above work correctly.
If I instead apply torch.utils.checkpoint.checkpoint to the two linear layers inside the TwoLinearLayerMLP, like this:
Code3
def forward(self, input):
    # checkpoint each linear layer separately instead of the whole model
    a1 = functional.relu(cp.checkpoint(self.linear1, input))
    a2 = cp.checkpoint(self.linear2, a1)
    return input + a2
then this code still works correctly:
Code4
with ddp_model.no_sync():
    for i in range(1, micro_batch_size):
        output = ddp_model(inputs[i])
        loss = loss_fn(output, labels[i])
        loss.backward()
output = ddp_model(inputs[0])
loss = loss_fn(output, labels[0])
loss.backward()
But the following code does not work correctly: the gradients are not synchronized across the DDP processes.
Code5
outputs = [0] * micro_batch_size
outputs[0] = ddp_model(inputs[0])
with ddp_model.no_sync():
    for i in range(1, micro_batch_size):
        outputs[i] = ddp_model(inputs[i])
for i in range(1, micro_batch_size):
    loss = loss_fn(outputs[i], labels[i])
    loss.backward()
loss = loss_fn(outputs[0], labels[0])
loss.backward()
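For context, this is roughly how I check whether the gradients ended up synchronized after the last backward; the broadcast-based comparison against rank 0 and the tolerance are just my own sanity check, not part of the training code:

for name, param in ddp_model.named_parameters():
    if param.grad is None:
        continue
    # compare the local gradient with the copy held by rank 0
    ref = param.grad.detach().clone()
    dist.broadcast(ref, src=0)
    max_diff = (param.grad - ref).abs().max().item()
    if dist.get_rank() != 0 and max_diff > 1e-6:
        print(f"rank {dist.get_rank()}: gradient mismatch for {name}: {max_diff}")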
My questions:
- How can I make Code5 work correctly? Also, is there anything I need to pay attention to when using ddp_model.no_sync() together with torch.utils.checkpoint.checkpoint?
- What is the difference between Code1 and Code2? I would like to use ddp_model.no_sync() in the style of Code2.