Manually reshard FSDP module after OOM

Hi PyTorch friends, is there a safe API that I can call to manually reshard FSDP?

Context:
We’re trying a batch-size auto-tuning idea. We use FSDP and start with a large batch size. We catch the OOM exception and retry with a smaller batch size, repeating this until we find a batch size that doesn’t OOM.

It looks like this:

while True:
    try:
        train_one_batch(fsdp_model, input_data)
    except torch.cuda.OutOfMemoryError:
        # reduce batch size
        # do I need to re-shard some of the FSDP modules manually here?
        ...
        continue
    break

But we found that, when we retry with a smaller batch size after the OOM, GPU memory usage is higher than at the beginning, right after the model was initialized. I guess that when the OOM happens, some FSDP modules may still be in the unsharded state, and we need to manually reshard them? If my hypothesis is right, is there already a safe API I can use to reshard the FSDP modules before training with the smaller batch size?
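
To check this hypothesis, I'm using a rough diagnostic along these lines. The helper name is just illustrative, and it pokes at FSDP1 private attributes like _flat_param and _local_shard, so treat it as a debugging sketch only:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel

def find_unsharded_fsdp_modules(model: torch.nn.Module):
    # Debugging sketch: list FSDP submodules whose flat param still aliases the
    # padded full parameter instead of the local shard (private FSDP1 attributes,
    # may change between PyTorch versions).
    unsharded = []
    for name, module in model.named_modules():
        if isinstance(module, FullyShardedDataParallel):
            flat_param = getattr(module, '_flat_param', None)
            if flat_param is None:
                continue
            if flat_param.data_ptr() != flat_param._local_shard.data_ptr():
                unsharded.append(name)
    return unsharded

# After catching the OOM, e.g.:
# print(find_unsharded_fsdp_modules(fsdp_model), torch.cuda.memory_allocated())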

cc @agu, @wconstab, could you help take a look? Thanks!

It looks like I can try to call this _post_forward(): pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py at main · pytorch/pytorch · GitHub. Not sure if there is a more appropriate way.

Actually, this callback is the API that does the final reshard + cleanup.

@agu, could we make this a public API? I think we need a public API to clean up FSDP state to avoid a memory leak after an exception.

It can be reproduced with this unit test:

import logging

import pytest
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

log = logging.getLogger(__name__)


def _fsdp_reshard_and_cleanup(model: torch.nn.Module):
    for name, module in model.named_modules():
        if isinstance(module, FullyShardedDataParallel):
            if module.check_is_root():
                # Running the final callback on the root reshards and cleans up
                # all nested FSDP instances.
                try:
                    _post_backward_final_callback(module, module)
                except Exception as e:
                    log.warning(f'Failed to reshard FSDP after OOM, error: {e}')

class SimpleMLPForTestingOOM(torch.nn.Module):

    def __init__(self, num_features: int = 128, device: str = 'cuda'):
        super().__init__()
        self.device = device
        self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
        self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
        self.fc3 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
        self.rank = dist.get_rank()

        def oom_hook(*args):
            # Simulate a CUDA OOM during backward so that FSDP's post-backward
            # reshard for fc2 never runs.
            raise RuntimeError('CUDA out of memory.')

        self.fc2.register_full_backward_hook(oom_hook)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

    def loss(self, outputs, batch):
        return torch.sum(outputs)


@pytest.mark.gpu
@world_size(2)
def test_fsdp_reshard_after_oom(world_size: int):
    model = SimpleMLPForTestingOOM()

    # Wrap each Linear in its own FSDP unit so fc2 has its own flat param.
    fsdp_model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy({torch.nn.Linear}))

    x = torch.rand([2, 128], device='cuda')
    output = fsdp_model(x)
    with pytest.raises(Exception):
        # Backward triggers the fake OOM exception,
        # which prevents the FSDP reshard and cleanup.
        torch.sum(output).backward()

    fc2_flat_param = fsdp_model.fc2._flat_param

    # Without cleanup, fc2's flat param is still in the unsharded state;
    # the padded full param has not been freed.
    assert fc2_flat_param.data_ptr() != fc2_flat_param._local_shard.data_ptr()
    assert fc2_flat_param._full_param_padded.numel() > 0

    _fsdp_reshard_and_cleanup(fsdp_model)
    assert fc2_flat_param.data_ptr() == fc2_flat_param._local_shard.data_ptr()
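
Putting it together, here is roughly how the cleanup would slot into the original retry loop. train_one_batch and make_batch are placeholders for our own training code, and torch.cuda.OutOfMemoryError is the OOM exception subclass raised by recent PyTorch versions; this is only a sketch, not a polished implementation:

while True:
    try:
        train_one_batch(fsdp_model, input_data)  # placeholder from the sketch above
    except torch.cuda.OutOfMemoryError:
        # Reshard any FSDP modules left unsharded by the failed step,
        # then release cached blocks before retrying.
        _fsdp_reshard_and_cleanup(fsdp_model)
        torch.cuda.empty_cache()
        batch_size //= 2
        input_data = make_batch(batch_size)  # placeholder: rebuild the smaller batch
        continue
    break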