FSDP doesn't reduce the GPU memory usage

Hi, I'm working on refactoring the DeiT code to integrate FSDP so I can train deit_large with a larger per-GPU batch size.
After importing the related FSDP modules, I simply changed my code [ref this link] to

if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model = FSDP(
            model,
            fsdp_auto_wrap_policy=default_auto_wrap_policy,
        )

Then I hit two problems:

  1. The FSDP model cannot run with model EMA.
  2. As mentioned in the title, the GPU memory usage is not reduced as expected.

Here is an example for problem 2:
run with FSDP:

GPU memory by nvidia-smi: 
18191MiB
terminal output by deit:
Epoch: [0]  [  200/40036]  eta: 10:42:06  lr: 0.000001  loss: 0.5857 (0.6690)  
time: 1.0125  data: 0.0036  max mem: 13402

run without FSDP:

GPU memory by nvidia-smi: 
19569MiB
terminal output by deit: 
Epoch: [0]  [ 200/5004]  eta: 1:47:13  lr: 0.000001  loss: 0.7104 (0.7275)  
time: 1.1631  data: 0.0002  max mem: 17334

Both stats were collected with:

python -m torch.distributed.launch \
           --nproc_per_node=8 --use_env train.py \
           --model deit_large_patch16_LS \
           --data-path ${imagenet_dir} \
           --output_dir ${output_dir} \
           --nb-classes 1000 \
           --batch 32 \
           --lr 3e-3 --epochs 1 \
           --weight-decay 0.05 --sched cosine \
           --input-size 224 \
           --reprob 0.0 \
           --smoothing 0.0 --warmup-epochs 5 --drop 0.0 \
           --seed 0 --opt lamb \
           --warmup-lr 1e-6 --mixup .8 --drop-path 0.6 --cutmix 1.0 \
           --unscale-lr --repeated-aug \
           --bce-loss  \
           --color-jitter 0.3 --ThreeAugment \
           --no-model-ema

Can anyone help?

@edrt you should not do

model = DDP(model, ...)
model = FSDP(model, ...)

Please look carefully at the tutorial you are using; it actually does:

model = DDP(model(), ...)
fsdp_model = FSDP(model(), ...)

So, just remove the DDP line.
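
For reference, a minimal sketch of what the suggested change could look like, assuming the same default_auto_wrap_policy import as in the question (note that newer PyTorch releases rename the keyword to auto_wrap_policy and the policy to size_based_auto_wrap_policy, so adjust to your version):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

if args.distributed:
    # Wrap the raw model with FSDP only; do not wrap it in DDP first,
    # since FSDP already handles the cross-rank gradient communication.
    model = FSDP(
        model,
        fsdp_auto_wrap_policy=default_auto_wrap_policy,
    )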

Thanks for your reply.
I tried model() instead of model first, but that failed with an error about a missing input, since model() calls the forward pass with no arguments.
I also tried removing the DDP line, but the GPU memory usage still didn't decrease as expected.
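
In case it helps with debugging, here is a small sketch of one way to check whether FSDP is actually sharding the parameters across ranks (this assumes the default flat-parameter sharding, so the counts include some padding):

import torch.distributed as dist

# After wrapping with FSDP (and without DDP), each rank should hold only
# roughly 1/world_size of the parameters outside of forward/backward.
# If this prints the full deit_large parameter count on every rank,
# the model is not being sharded.
local_numel = sum(p.numel() for p in model.parameters())
print(f"rank {dist.get_rank()}: local parameter numel = {local_numel}")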

@Yanli_Zhao could you please help?