FSDP with MONAI tutorial

Hi there!

I’m attempting to use FSDP for medical image segmentation to reduce the GPU memory footprint during training. As a starting point, I’m trying to adapt this tutorial from MONAI to use FSDP with 2 GPUs.

I’ve created a fork of the original tutorial that instead spawns 2 processes, each of which runs a training loop with the model wrapped in FSDP. I had to split the fsdp_main function out into its own file due to this error.

import torch.multiprocessing as mp
from fsdp_main import fsdp_main

WORLD_SIZE = 2
if __name__ == "__main__":
    mp.spawn(fsdp_main, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)
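
For context, the setup at the top of fsdp_main follows the standard PyTorch FSDP recipe. A minimal sketch (the MONAI data pipeline and training loop are elided, and build_unet() is a hypothetical stand-in for the tutorial's model construction):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def fsdp_main(rank, world_size):
    # one process per GPU; mp.spawn supplies rank as the first argument
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # pin this process to cuda:rank

    model = build_unet().to(rank)  # hypothetical helper standing in for the tutorial's UNet
    model = FSDP(model)            # shard the parameters across both GPUs

    # ... MONAI transforms, DataLoader with a DistributedSampler, and the
    # training loop (outputs = model(inputs), loss.backward(), ...) go here

    dist.destroy_process_group()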

I’m currently seeing this printed error message:

Expects tensor to be on the compute device cuda:1

And the following stack trace:

ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/baclark/code/Project-MONAI-tutorials/3d_segmentation/fsdp_main.py", line 240, in fsdp_main
    outputs = model(inputs)
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2720, in forward
    self._pre_forward(self._handles, unshard_fn, unused, unused)
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2756, in _pre_forward
    unshard_fn()
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2768, in _pre_forward_unshard
    self._unshard(handles)
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1566, in _unshard
    handle.post_unshard()
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 768, in post_unshard
    self._check_on_compute_device(self.flat_param)
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1064, in _check_on_compute_device
    p_assert(
  File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/_utils.py", line 149, in p_assert
    raise AssertionError
AssertionError

This seems to occur when the FSDP module unshards the parameters before performing the forward pass: for some reason this particular FlatParamHandle has been unsharded onto a different device from the one stored in its self.device attribute.
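
Based on the trace, the assertion that fires is the compute-device check in FlatParamHandle. Paraphrasing rather than quoting flat_param.py, it is equivalent to:

# rough paraphrase of FlatParamHandle._check_on_compute_device (not the literal source)
assert flat_param.device == handle.device, (
    f"Expects tensor to be on the compute device {handle.device}"
)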

I have checked that all input/label data and model parameters are on the correct devices before performing the forward pass.
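
Those checks look roughly like this in each process (simplified; rank is the index passed in by mp.spawn):

expected = torch.device(f"cuda:{rank}")
assert inputs.device == expected and labels.device == expected
for name, p in model.named_parameters():
    assert p.device == expected, f"{name} is on {p.device}, expected {expected}"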

The problem should be reproducible if you run the forked code with the versions listed below and update this directory, which is where the segmentation data is downloaded.

CUDA                            11.7
python                          3.8.6
monai-weekly                    1.2.dev2252
torch                           1.13.1
torchaudio                      0.13.1
torchmetrics                    0.11.0
torchvision                     0.14.1

Any ideas why there might be a device mismatch for these flattened parameters?

Thanks,
Brett