hi, I’m working on refactoring the DeiT code to integrate FSDP, so that I can train deit_large with a larger per-GPU batch size.
After importing the related FSDP modules, I simply changed my code [ref this link] to:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

if args.distributed:
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])  # replaced DDP with FSDP
    model = FSDP(
        model,
        fsdp_auto_wrap_policy=default_auto_wrap_policy,
    )
Then I hit two problems:
- The FSDP-wrapped model cannot run with model EMA (a minimal repro sketch is right after this list).
- As mentioned above, GPU memory usage is not reduced as much as expected.
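For the first problem, here is a minimal repro sketch of the pattern that fails. It is simplified from how deit uses timm's ModelEma (the EMA copy is created from the unwrapped model in main.py, then model_ema.update(model) is called on the wrapped model every step in engine.py), so treat the names and the exact failure point as approximate:

import torch
from timm.utils import ModelEma
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# minimal sketch, assuming the default process group is already initialized
# (one GPU per process, as in the launch command further down)
model = torch.nn.Linear(512, 512).cuda()

# the EMA copy is built from the unwrapped model, as deit's main.py does
model_ema = ModelEma(model, decay=0.99996, device='', resume='')

# the model is then wrapped with FSDP as shown above
model = FSDP(model)

loss = model(torch.randn(8, 512, device='cuda')).sum()
loss.backward()

# ModelEma.update() reads model.state_dict() and copies values key-by-key into
# its own copy; with the FSDP-wrapped model the keys (and possibly the shapes,
# since parameters are flattened and sharded) no longer match the plain copy,
# so this update step is where it falls over
model_ema.update(model)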
Here is an example for the second problem (memory usage):
run with FSDP:
GPU memory by nvidia-smi:
18191MiB
terminal output by deit:
Epoch: [0] [ 200/40036] eta: 10:42:06 lr: 0.000001 loss: 0.5857 (0.6690) time: 1.0125 data: 0.0036 max mem: 13402
run without FSDP:
GPU memory by nvidia-smi:
19569MiB
terminal output by deit:
Epoch: [0] [ 200/5004] eta: 1:47:13 lr: 0.000001 loss: 0.7104 (0.7275) time: 1.1631 data: 0.0002 max mem: 17334
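To be clear about what the two numbers measure: if I read deit's utils.MetricLogger correctly, the "max mem" column is the peak CUDA allocator statistic, while nvidia-smi shows the whole process footprint (CUDA context plus memory the caching allocator keeps reserved). Roughly:

import torch

MB = 1024.0 * 1024.0

# this is what the "max mem" column in the deit log corresponds to:
# the peak amount of memory actually occupied by tensors on this GPU
max_allocated = torch.cuda.max_memory_allocated() / MB

# nvidia-smi instead reports the full per-process footprint, which also
# includes the CUDA context and memory the caching allocator has reserved
# but is not currently using; max_memory_reserved() is closer to that number
max_reserved = torch.cuda.max_memory_reserved() / MB

print(f"max allocated: {max_allocated:.0f}MB, max reserved: {max_reserved:.0f}MB")

For the runs above that is a drop of roughly 3.9GB in "max mem" but only about 1.4GB in nvidia-smi, still far less than I expected from sharding deit_large over 8 GPUs.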
Both runs were launched with:
python -m torch.distributed.launch \
--nproc_per_node=8 --use_env train.py \
--model deit_large_patch16_LS \
--data-path ${imagenet_dir} \
--output_dir ${output_dir} \
--nb-classes 1000 \
--batch 32 \
--lr 3e-3 --epochs 1 \
--weight-decay 0.05 --sched cosine \
--input-size 224 \
--reprob 0.0 \
--smoothing 0.0 --warmup-epochs 5 --drop 0.0 \
--seed 0 --opt lamb \
--warmup-lr 1e-6 --mixup .8 --drop-path 0.6 --cutmix 1.0 \
--unscale-lr --repeated-aug \
--bce-loss \
--color-jitter 0.3 --ThreeAugment \
--no-model-ema
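For what it's worth, here is a rough sketch of what I plan to check next: whether the auto-wrap policy actually splits the model into separately sharded units, and how many parameters each rank holds locally (as far as I understand, default_auto_wrap_policy only gives submodules their own wrapper above a parameter-count threshold, so small blocks all stay in the root unit):

import torch.distributed as dist

# sketch: run right after `model = FSDP(...)` above
if dist.get_rank() == 0:
    # submodules that got their own wrapper show up as nested
    # FullyShardedDataParallel entries in the printed module tree
    print(model)

# outside forward/backward each rank should only hold its shard of the
# flattened parameters; if this is close to the full ~300M parameters of
# deit_large, sharding is not taking effect the way I expect
local_numel = sum(p.numel() for p in model.parameters())
print(f"rank {dist.get_rank()}: {local_numel / 1e6:.1f}M parameter elements held locally")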
Can anyone help?