Hi there!
I’m attempting to use FSDP for medical image segmentation to reduce the GPU memory footprint during training. As a starting point, I’m trying to adapt this tutorial from MONAI to use FSDP with 2 GPUs.
I’ve created a fork of the original tutorial that instead spawns 2 processes, each of which runs a training loop with a module wrapped in FSDP. I had to split the fsdp_main function out into its own file due to this error.
import torch.multiprocessing as mp

from fsdp_main import fsdp_main

WORLD_SIZE = 2
mp.spawn(fsdp_main, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)
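For reference, fsdp_main follows the standard FSDP multiprocessing pattern (a simplified sketch, not verbatim from my fork; build_model is a placeholder for the MONAI UNet construction in the tutorial):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def fsdp_main(rank, world_size):
    # mp.spawn passes the process index (rank) as the first argument.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Pin this process to its GPU before building and wrapping the model,
    # so FSDP's notion of the compute device matches the parameters.
    torch.cuda.set_device(rank)

    model = FSDP(build_model().to(rank))  # build_model: placeholder

    # ... training loop from the tutorial: outputs = model(inputs), etc.

    dist.destroy_process_group()
```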
I’m currently seeing this printed error message:
Expects tensor to be on the compute device cuda:1
And the following stack trace:
ProcessRaisedException:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/baclark/code/Project-MONAI-tutorials/3d_segmentation/fsdp_main.py", line 240, in fsdp_main
outputs = model(inputs)
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2720, in forward
self._pre_forward(self._handles, unshard_fn, unused, unused)
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2756, in _pre_forward
unshard_fn()
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2768, in _pre_forward_unshard
self._unshard(handles)
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1566, in _unshard
handle.post_unshard()
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 768, in post_unshard
self._check_on_compute_device(self.flat_param)
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1064, in _check_on_compute_device
p_assert(
File "/home/baclark/venvs/medical-imaging/lib/python3.8/site-packages/torch/distributed/fsdp/_utils.py", line 149, in p_assert
raise AssertionError
AssertionError
This seems to occur when the FSDP module unshards the parameters before performing the forward pass. For some reason this particular FlatParamHandle has been unsharded onto a different device than the one stored in its self.device attribute.
I have checked that all input/label data and model parameters are on the correct devices before performing the forward pass.
The problem should be reproducible if you run the forked code with the versions listed below and update this directory, to which the segmentation data is downloaded.
CUDA 11.7
python 3.8.6
monai-weekly 1.2.dev2252
torch 1.13.1
torchaudio 0.13.1
torchmetrics 0.11.0
torchvision 0.14.1

Any ideas why there might be a device mismatch for these flattened parameters?
Thanks,
Brett