### 🐛 Describe the bug
Currently, when using FSDP, the model is loaded completely on CPU by each of the N processes, leading to huge CPU RAM usage. When training models like Falcon-40B with FSDP on a DGX node with 8 GPUs, CPU RAM runs out because each process loads 160GB (40B params x 4 bytes (FP32)) into CPU RAM, for a total requirement of 160*8=1280GB, and the script gets killed due to running out of CPU RAM.
To combat this, we are trying to load the model only on rank 0 and keep it on the `meta` device on every other rank, and then use `param_init_fn` along with `sync_module_states=True` so that FSDP properly initializes the weights on the other ranks and broadcasts the params from rank 0 to them. **This is trying to achieve what `zero.Init()` from DeepSpeed does; it would be great for FSDP to support this out of the box.**
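A minimal plain-PyTorch sketch of the intended pattern (the actual script is in the repo linked below; the `Auto*` classes and the lambda `param_init_fn` here are illustrative, not the exact code):
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForSequenceClassification

# assumes the process group is already initialized (e.g. via accelerate launch / torchrun)
rank = dist.get_rank()
model_name = "bert-base-uncased"

if rank == 0:
    # only rank 0 materializes the pretrained weights (on CPU)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
else:
    # every other rank only builds the architecture on the meta device (no weight memory)
    config = AutoConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = AutoModelForSequenceClassification.from_config(config)

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # broadcast rank 0's weights to the other ranks after wrapping
    sync_module_states=True,
    # materialize meta-device modules with empty (uninitialized) storage;
    # the broadcast above should then overwrite them with the real weights
    param_init_fn=lambda module: module.to_empty(device=torch.cuda.current_device()),
    # ... plus auto_wrap_policy, sharding_strategy, etc. from config.yaml below
)
```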
However, with the above approach, the accuracy and F1 metrics stay at chance level, i.e., the model isn't learning anything even though the weights appear to change and the train loss decreases slightly.
Code: https://github.com/pacman100/ram_efficient_fsdp
Steps:
1. pip install -r requirements.txt
2. bash run.sh
The FSDP config is in `config.yaml`
```
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BertLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
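For reference, my understanding is that these accelerate keys map roughly to the following plain-PyTorch FSDP constructor arguments (a rough mapping for readers not familiar with the accelerate config, not code from the linked script):
```python
import functools
import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.bert.modeling_bert import BertLayer

# fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: BertLayer
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # fsdp_backward_prefetch_policy
    forward_prefetch=True,                            # fsdp_forward_prefetch
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # fsdp_sharding_strategy: 1
    use_orig_params=True,                             # fsdp_use_orig_params
    sync_module_states=True,                          # fsdp_sync_module_states
    device_id=torch.cuda.current_device(),
)
```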
Output:
```
[2023-07-24 16:34:14,815] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-07-24 16:34:19,709] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-07-24 16:34:19,736] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
DistributedType.FSDP
wandb: Currently logged in as: smangrul. Use `wandb login --relogin` to force relogin
Found cached dataset glue (/raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 1269.59it/s]
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in /home/sourab/ram_efficient_fsdp/wandb/run-20230724_163421-hshg5m1t
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run whole-fire-12
wandb: ⭐️ View project at https://wandb.ai/smangrul/fsdp_glue_no_trainer
wandb: 🚀 View run at https://wandb.ai/smangrul/fsdp_glue_no_trainer/runs/hshg5m1t
Found cached dataset glue (/raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 1229.52it/s]
Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c8be3b6cea2b5568.arrow
Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-98f28c5bb15f064c.arrow
accelerator.process_index=1 model.bert.pooler.dense.weight=Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
accelerator.process_index=1 model.classifier.weight=Parameter containing:
tensor(..., device='meta', size=(2, 768), requires_grad=True)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
accelerator.process_index=0 model.bert.pooler.dense.weight=Parameter containing:
tensor([[-0.0013, -0.0381, -0.0158, ..., 0.0244, -0.0008, 0.0240],
[ 0.0020, 0.0151, 0.0033, ..., 0.0180, -0.0023, 0.0231],
[-0.0386, 0.0145, 0.0621, ..., 0.0374, -0.0105, -0.0395],
...,
[-0.0111, 0.0136, 0.0541, ..., 0.0666, 0.0017, -0.0090],
[ 0.0001, 0.0024, -0.0125, ..., 0.0046, -0.0014, -0.0079],
[ 0.0415, 0.0751, 0.0305, ..., 0.0317, 0.0479, 0.0080]],
requires_grad=True)
accelerator.process_index=0 model.classifier.weight=Parameter containing:
tensor([[-0.0025, 0.0011, -0.0052, ..., -0.0212, 0.0227, 0.0206],
[ 0.0151, -0.0045, 0.0243, ..., -0.0208, -0.0183, -0.0203]],
requires_grad=True)
FullyShardedDataParallel(
  (_fsdp_wrapped_module): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x FullyShardedDataParallel(
            (_fsdp_wrapped_module): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
                (intermediate_act_fn): GELUActivation()
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (classifier): Linear(in_features=768, out_features=2, bias=True)
  )
)
accelerator.process_index=0 model.bert.pooler.dense.weight=Parameter containing:
tensor([[-0.0013, -0.0381, -0.0158, ..., 0.0244, -0.0008, 0.0240],
[ 0.0020, 0.0151, 0.0033, ..., 0.0180, -0.0023, 0.0231],
[-0.0386, 0.0145, 0.0621, ..., 0.0374, -0.0105, -0.0395],
...,
[-0.0111, 0.0136, 0.0541, ..., 0.0666, 0.0017, -0.0090],
[ 0.0001, 0.0024, -0.0125, ..., 0.0046, -0.0014, -0.0079],
[ 0.0415, 0.0751, 0.0305, ..., 0.0317, 0.0479, 0.0080]],
device='cuda:0', requires_grad=True)
accelerator.process_index=0 model.classifier.weight=Parameter containing:
tensor([[-0.0025, 0.0011, -0.0052, ..., -0.0212, 0.0227, 0.0206],
[ 0.0151, -0.0045, 0.0243, ..., -0.0208, -0.0183, -0.0203]],
device='cuda:0', requires_grad=True)
accelerator.process_index=1 model.bert.pooler.dense.weight=Parameter containing:
tensor([[-0.0013, -0.0381, -0.0158, ..., 0.0244, -0.0008, 0.0240],
[ 0.0020, 0.0151, 0.0033, ..., 0.0180, -0.0023, 0.0231],
[-0.0386, 0.0145, 0.0621, ..., 0.0374, -0.0105, -0.0395],
...,
[-0.0111, 0.0136, 0.0541, ..., 0.0666, 0.0017, -0.0090],
[ 0.0001, 0.0024, -0.0125, ..., 0.0046, -0.0014, -0.0079],
[ 0.0415, 0.0751, 0.0305, ..., 0.0317, 0.0479, 0.0080]],
device='cuda:1', requires_grad=True)
accelerator.process_index=1 model.classifier.weight=Parameter containing:
tensor([[-0.0025, 0.0011, -0.0052, ..., -0.0212, 0.0227, 0.0206],
[ 0.0151, -0.0045, 0.0243, ..., -0.0208, -0.0183, -0.0203]],
device='cuda:1', requires_grad=True)
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Memory before entering the train : 212
Memory consumed at the end of the train (end-begin): 659
Peak Memory consumed during the train (max-begin): 1995
Total Peak Memory consumed during the train (max): 2207
epoch 0: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
Memory before entering the eval : 872
Memory consumed at the end of the eval (end-begin): 94
Peak Memory consumed during the eval (max-begin): 209
Total Peak Memory consumed during the eval (max): 1081
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Memory before entering the train : 966
Memory consumed at the end of the train (end-begin): -94
Peak Memory consumed during the train (max-begin): 1218
Total Peak Memory consumed during the train (max): 2184
epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
Memory before entering the eval : 872
Memory consumed at the end of the eval (end-begin): 94
Peak Memory consumed during the eval (max-begin): 209
Total Peak Memory consumed during the eval (max): 1081
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Memory before entering the train : 966
Memory consumed at the end of the train (end-begin): -94
Peak Memory consumed during the train (max-begin): 1297
Total Peak Memory consumed during the train (max): 2263
epoch 2: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
Memory before entering the eval : 872
Memory consumed at the end of the eval (end-begin): 94
Peak Memory consumed during the eval (max-begin): 209
Total Peak Memory consumed during the eval (max): 1081
wandb: Waiting for W&B process to finish... (success).
wandb:
wandb: Run history:
wandb: accuracy ▁▁▁
wandb: eval_total_peak_memory ▁▁▁
wandb: f1 ▁▁▁
wandb: train_loss █▂▁
wandb: train_total_peak_memory ▃▁█
wandb:
wandb: Run summary:
wandb: accuracy 0.68382
wandb: eval_total_peak_memory 1081
wandb: f1 0.81223
wandb: train_loss 0.63513
wandb: train_total_peak_memory 2263
wandb:
wandb: 🚀 View run whole-fire-12 at: https://wandb.ai/smangrul/fsdp_glue_no_trainer/runs/hshg5m1t
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230724_163421-hshg5m1t/logs
```
As you can see, the performance remains the same across all 3 epochs, i.e., chance-level performance.
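One way to programmatically sanity-check that the rank 0 weights were actually broadcast, instead of eyeballing the printed tensors above (a hypothetical diagnostic helper, not part of the linked script):
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def params_in_sync(model: FSDP) -> bool:
    """Compare a checksum of the unsharded parameters and buffers across all ranks."""
    # gather the full parameters on every rank without writing anything back
    with FSDP.summon_full_params(model, writeback=False):
        checksum = sum(p.float().sum().item() for p in model.parameters())
        checksum += sum(b.float().sum().item() for b in model.buffers())
    local = torch.tensor([checksum], device=torch.cuda.current_device())
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    return all(torch.allclose(t, gathered[0]) for t in gathered)
```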
**Expected Behaviour:**
The model should learn when using `param_init_fn` and `sync_module_states=True` with FSDP, so that the pretrained model can be loaded only on rank 0 and kept on the `meta` device on every other rank. This is required for FSDP to be usable with large models in practice.
### Versions
PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.26.4
Libc version: glibc-2.31
Python version: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] lion-pytorch==0.0.6
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] pytorch-lightning==1.9.0
[pip3] pytorch-triton==2.1.0+9e3e10c5ed
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchmetrics==0.11.4
[pip3] torchvision==0.15.2+cu118
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] lion-pytorch 0.0.6 pypi_0 pypi
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.6 py310h1128e8f_1
[conda] mkl_random 1.2.2 py310h1128e8f_1
[conda] numpy 1.24.4 pypi_0 pypi
[conda] numpy-base 1.25.0 py310hb5e798b_0
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-lightning 1.9.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-triton 2.1.0+9e3e10c5ed pypi_0 pypi
[conda] torch 2.0.1+cu118 pypi_0 pypi
[conda] torchaudio 2.0.2+cu118 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchvision 0.15.2+cu118 pypi_0 pypi
cc @zhaojuanmao @mrshenli @rohan-varma @awgu