OOM when using multi-gpu

Hi, everyone!
I trained a model with fairseq on 3090 GPUs, using the default Adam optimizer (via the fairseq-train command). Training ran fine on a single GPU, with no OOM or other errors. But when I tried to use two GPUs, an OOM occurred, as shown below.
According to the traceback, it seems to occur in the optimizer step. Strangely, the memory summary shows 0 memory allocated on device 0, while device 1 has a large amount allocated.

| WARNING | fairseq.trainer | OOM: Ran out of memory with exception: CUDA out of memory. Tried to allocate 3.02 GiB (GPU 1; 23.70 GiB total capacity; 16.92 GiB already allocated; 1019.69 MiB free; 21.03 GiB reserved in total by PyTorch)
2022-04-14 02:39:35 | WARNING | fairseq.trainer | |===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 | 
|---------------------------------------------------------------------------| 
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         | 
|===========================================================================| 
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  | 
|---------------------------------------------------------------------------| 
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  | 
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  | 
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  | 
|---------------------------------------------------------------------------| 
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  | 
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  | 
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  | 
|---------------------------------------------------------------------------| 
| GPU reserved memory   |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Allocations           |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    | 
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|===========================================================================|

2022-04-14 02:39:35 | WARNING | fairseq.trainer | |===========================================================================|
|                  PyTorch CUDA memory summary, device ID 1                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   17329 MB |   20441 MB |  315650 MB |  298320 MB |
|       from large pool |   17326 MB |   20437 MB |  303570 MB |  286244 MB |
|       from small pool |       3 MB |     280 MB |   12079 MB |   12076 MB |
|---------------------------------------------------------------------------|
| Active memory         |   17329 MB |   20441 MB |  315650 MB |  298320 MB |
|       from large pool |   17326 MB |   20437 MB |  303570 MB |  286244 MB |
|       from small pool |       3 MB |     280 MB |   12079 MB |   12076 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   21530 MB |   21830 MB |   43546 MB |   22016 MB |
|       from large pool |   21464 MB |   21624 MB |   41444 MB |   19980 MB |
|       from small pool |      66 MB |     292 MB |    2102 MB |    2036 MB |
|---------------------------------------------------------------------------|
| Non-releasable memory |    4200 MB |    4202 MB |  264811 MB |  260611 MB |
|       from large pool |    4137 MB |    4139 MB |  250471 MB |  246333 MB |
|       from small pool |      62 MB |     112 MB |   14340 MB |   14277 MB |
|---------------------------------------------------------------------------|
| Allocations           |    2280    |    3778    |  177529    |  175249    |
|       from large pool |     929    |    1345    |   52873    |   51944    |
|       from small pool |    1351    |    2537    |  124656    |  123305    |
|---------------------------------------------------------------------------|
| Active allocs         |    2280    |    3778    |  177529    |  175249    |
|       from large pool |     929    |    1345    |   52873    |   51944    |
|       from small pool |    1351    |    2537    |  124656    |  123305    |
|---------------------------------------------------------------------------|
| GPU reserved segments |     128    |     265    |    1846    |    1718    |
|       from large pool |      95    |     162    |     795    |     700    |
|       from small pool |      33    |     146    |    1051    |    1018    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |     135    |     181    |   85994    |   85859    |
|       from large pool |      53    |      68    |   31036    |   30983    |
|       from small pool |      82    |     146    |   54958    |   54876    |
|===========================================================================|
2022-04-14 02:39:35 | ERROR | fairseq.trainer | OOM during optimization, irrecoverable
Traceback (most recent call last):
  File "/home/xjw/miniconda3/envs/cliff/bin/fairseq-train", line 8, in <module>
    sys.exit(cli_main())
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq_cli/train.py", line 392, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq/distributed_utils.py", line 318, in call_main
    cfg.distributed_training.distributed_world_size,
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq/distributed_utils.py", line 300, in distributed_main
    main(cfg, **kwargs)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq_cli/train.py", line 130, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq_cli/train.py", line 219, in train
    log_output = trainer.train_step(samples)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq/trainer.py", line 674, in train_step
    raise e
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq/trainer.py", line 647, in train_step
    self.optimizer.step()       
    self.optimizer.step(closure)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/torch/optim/optimizer.py", line 89, in wrapper
    return func(*args, **kwargs)
  File "/home/xjw/miniconda3/envs/cliff/lib/python3.6/site-packages/fairseq/optim/adam.py", line 210, in step
    denom = exp_avg_sq.sqrt().add_(group["eps"])
RuntimeError: CUDA out of memory. Tried to allocate 3.02 GiB (GPU 1; 23.70 GiB total capacity; 16.92 GiB already allocated; 1019.69 MiB free; 21.03 GiB reserved in total by PyTorch)
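
To make the numbers in the error message above easier to compare, here is a small sketch that parses the message and computes how much memory PyTorch had reserved but not allocated on GPU 1. The helper name `parse_oom` and the regular expressions are my own for illustration, not part of fairseq or PyTorch, and they assume the message format shown above.

```python
import re

# OOM message copied from the traceback above.
MESSAGE = (
    "CUDA out of memory. Tried to allocate 3.02 GiB (GPU 1; 23.70 GiB total "
    "capacity; 16.92 GiB already allocated; 1019.69 MiB free; 21.03 GiB "
    "reserved in total by PyTorch)"
)

UNIT_TO_GIB = {"MiB": 1 / 1024, "GiB": 1.0}


def parse_oom(message):
    """Return the sizes named in the OOM message, converted to GiB.

    Hypothetical helper: it only understands the message layout above.
    """
    patterns = {
        "requested": r"Tried to allocate ([\d.]+) (MiB|GiB)",
        "total": r"([\d.]+) (MiB|GiB) total capacity",
        "allocated": r"([\d.]+) (MiB|GiB) already allocated",
        "free": r"([\d.]+) (MiB|GiB) free",
        "reserved": r"([\d.]+) (MiB|GiB) reserved",
    }
    sizes = {}
    for name, pattern in patterns.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"field {name!r} not found in message")
        sizes[name] = float(match.group(1)) * UNIT_TO_GIB[match.group(2)]
    return sizes


sizes = parse_oom(MESSAGE)
# Memory held by PyTorch's caching allocator but not backing a live tensor.
cached_unused = sizes["reserved"] - sizes["allocated"]
print(f"requested:         {sizes['requested']:.2f} GiB")
print(f"free on device:    {sizes['free']:.2f} GiB")
print(f"cached but unused: {cached_unused:.2f} GiB")
```

The 3.02 GiB request is larger than the roughly 1 GiB free on the device but smaller than the roughly 4.11 GiB that is reserved and unused, which is close to the 4200 MB "Non-releasable memory" row in the device 1 summary. That may point to fragmentation of the cached blocks rather than a simple lack of capacity, but I have not confirmed this.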

My training script is shown below. The only thing I changed for the multi-GPU run was DEVICES.

TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=1024
UPDATE_FREQ=4
BART_PATH=$1
DATA_DIR=$2
USER_DIR=$3
SAVE_PATH=$4
TENSOR_LOG_PATH=$5
DEVICES=4,5

CUDA_VISIBLE_DEVICES=$DEVICES fairseq-train $DATA_DIR \
    --facets Purpose,Method,Findings \
    --max-epoch 10 \
    --tensorboard-logdir $TENSOR_LOG_PATH \
    --restore-file $BART_PATH --save-dir $SAVE_PATH \
    --max-tokens $MAX_TOKENS \
    --task divide_translation \
    --source-lang source --target-lang Purpose.target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch divide_bart_large \
    --criterion divide_loss \
    --label-smoothing 0.1 \
    --fixed-validation-seed 7 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --no-save-optimizer-state \
    --find-unused-parameters \
    --user-dir $USER_DIR;
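
For reference, a rough sketch of how the batch settings in this script scale with the number of GPUs, assuming (as I understand fairseq's data-parallel training) that `--max-tokens` applies per GPU and `--update-freq` accumulates gradients over that many batches. The function name `tokens_per_update` is only illustrative.

```python
def tokens_per_update(max_tokens, update_freq, num_gpus):
    """Upper bound on tokens contributing to one optimizer step.

    Assumes max_tokens is a per-GPU limit and update_freq is the number
    of batches accumulated on each GPU before the step.
    """
    return max_tokens * update_freq * num_gpus


# Values from the script above: MAX_TOKENS=1024, UPDATE_FREQ=4.
single_gpu = tokens_per_update(1024, 4, 1)
two_gpus = tokens_per_update(1024, 4, 2)
print(f"1 GPU:  up to {single_gpu} tokens per update")
print(f"2 GPUs: up to {two_gpus} tokens per update")
```

Under that assumption the per-GPU batch limit is the same in both runs and only the effective batch per update doubles, so I would not expect the per-device memory for activations to grow just from adding a second GPU.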