PyTorch FSDP Tutorial not working on torch 1.12

Hello, and Merry Christmas to all of you :)
I’m currently working through the PyTorch FSDP tutorials.

I successfully ran the first tutorial. However, when running the second script, which wraps Hugging Face T5 blocks, I got the following errors:

Python 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)
[GCC 10.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import MixedPrecision
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'MixedPrecision' from 'torch.distributed.fsdp' (/opt/conda/lib/python3.8/site-packages/torch/distributed/fsdp/__init__.py)
>>> from torch.distributed.fsdp import BackwardPrefetch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'BackwardPrefetch' from 'torch.distributed.fsdp' (/opt/conda/lib/python3.8/site-packages/torch/distributed/fsdp/__init__.py)
>>> from torch.distributed.fsdp import ShardingStrategy
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'ShardingStrategy' from 'torch.distributed.fsdp' (/opt/conda/lib/python3.8/site-packages/torch/distributed/fsdp/__init__.py)
>>> from torch.distributed.fsdp import FullStateDictConfig
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'FullStateDictConfig' from 'torch.distributed.fsdp' (/opt/conda/lib/python3.8/site-packages/torch/distributed/fsdp/__init__.py)
>>> from torch.distributed.fsdp import StateDictType
>>>

It seems that MixedPrecision, BackwardPrefetch, ShardingStrategy, and FullStateDictConfig cannot be imported, while StateDictType imports without error.
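To list in one pass which of these names a given build exposes, instead of hitting one traceback per import, a small probe like the one below could help. It is a minimal sketch that only uses importlib and hasattr; the helper name missing_fsdp_names is hypothetical, and the name list is simply the set of imports from the session above.

```python
import importlib

# Names imported from torch.distributed.fsdp in the session above
REQUIRED = (
    "FullyShardedDataParallel",
    "MixedPrecision",
    "BackwardPrefetch",
    "ShardingStrategy",
    "FullStateDictConfig",
    "StateDictType",
)


def missing_fsdp_names(names=REQUIRED):
    """Return the names that torch.distributed.fsdp does not expose (hypothetical helper)."""
    try:
        fsdp = importlib.import_module("torch.distributed.fsdp")
    except ImportError:
        # torch (or its distributed package) is not importable at all,
        # so every requested name counts as missing
        return list(names)
    # hasattr is enough here because all of these names are classes/enums,
    # not lazily imported submodules
    return [name for name in names if not hasattr(fsdp, name)]


if __name__ == "__main__":
    print(missing_fsdp_names())
```

Going by the tracebacks above, the 1.12.0a0+bd13bc6 build would be expected to report MixedPrecision, BackwardPrefetch, ShardingStrategy, and FullStateDictConfig as missing.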

Could you please look into this issue? Thanks in advance.

My environment:

  • docker image: nvcr.io/nvidia/pytorch:22.04-py3
  • python3 env
python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.12.0a0+bd13bc6
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.22.3
Libc version: glibc-2.31

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-4.19.93-1.nbp.el7.x86_64-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.6.124
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.65.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.12.0a0+bd13bc6
[pip3] torch-tensorrt==1.1.0a0
[pip3] torchtext==0.13.0a0
[pip3] torchvision==0.13.0a0
[conda] magma-cuda110             2.5.2                         5    local
[conda] mkl                       2019.5                      281    conda-forge
[conda] mkl-include               2019.5                      281    conda-forge
[conda] numpy                     1.22.3           py38h1d589f8_2    conda-forge
[conda] pytorch-quantization      2.1.2                    pypi_0    pypi
[conda] torch                     1.12.0a0+bd13bc6          pypi_0    pypi
[conda] torch-tensorrt            1.1.0a0                  pypi_0    pypi
[conda] torchtext                 0.13.0a0                 pypi_0    pypi
[conda] torchvision               0.13.0a0                 pypi_0    pypi

It works for me in 1.12.1:

>>> import torch
>>> torch.__version__
'1.12.1+cu116'
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> FullStateDictConfig
<class 'torch.distributed.fsdp.fully_sharded_data_parallel.FullStateDictConfig'>

The NGC 22.04 container ships a pre-release 1.12 commit (1.12.0a0+bd13bc6) that predates these classes, so update to a more recent container or install newer wheels (e.g. 1.12.1, as shown above).
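The trap here is that 1.12.0a0+bd13bc6 looks like "1.12" at a glance, but the a0 suffix marks an alpha build cut from a development commit. A minimal sketch for telling the two apart from the torch.__version__ string follows; the helper name parse_torch_version is hypothetical, and it assumes the simple PEP 440-style forms seen in this thread (X.Y.Z, an optional a/b/rc pre-release tag, an optional +local suffix). It only inspects the string and cannot tell which features a particular commit contains.

```python
import re

# major.minor.patch, optional pre-release tag (a/b/rc + number), optional +local part
_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:(a|b|rc)(\d+))?(?:\+(\S+))?$")


def parse_torch_version(version):
    """Split a version string into (release tuple, is_prerelease, local tag)."""
    m = _VERSION_RE.match(version)
    if m is None:
        # e.g. nightly strings such as "X.Y.Z.devYYYYMMDD" are not handled by this sketch
        raise ValueError(f"unrecognized version string: {version!r}")
    release = tuple(int(part) for part in m.group(1, 2, 3))
    is_prerelease = m.group(4) is not None
    local = m.group(6)
    return release, is_prerelease, local


if __name__ == "__main__":
    for v in ("1.12.0a0+bd13bc6", "1.12.1+cu116"):
        print(v, parse_torch_version(v))
```

With the two versions from this thread, the NGC build parses as a pre-release of (1, 12, 0) with local tag bd13bc6, while 1.12.1+cu116 parses as a regular release.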
