CUDA out of memory even after using FSDP (fully sharded data parallel)

I am trying to train a large model on an HPC cluster with SLURM and I get torch.cuda.OutOfMemoryError: CUDA out of memory even though I am using FSDP. I use Hugging Face Accelerate for the setup. Below is the error:

File "/project/p_trancal/CamLidCalib_Trans/Models/Encoder.py", line 45, in forward
    atten_out, atten_out_para = self.atten(x,x,x, attn_mask = attn_mask)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1126, in forward
    attn_mask = F._canonical_mask(
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/nn/functional.py", line 5115, in _canonical_mask
    torch.zeros_like(mask, dtype=target_type)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 29.07 GiB. GPU 0 has a total capacity of 39.43 GiB of which 25.15 GiB is free. Including non-PyTorch memory, this process has 14.27 GiB memory in use. Of the allocated memory 11.74 GiB is allocated by PyTorch, and 931.17 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Traceback (most recent call last):
[2024-03-31 20:37:40,661] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 806649 closing signal SIGTERM
[2024-03-31 20:37:40,674] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 805748 closing signal SIGTERM
[2024-03-31 20:37:41,125] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 806645) of binary: /project/p_trancal/trsclbjob/bin/python
Traceback (most recent call last):
  File "/project/p_trancal/trsclbjob/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1044, in launch_command
    multi_gpu_launcher(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 702, in multi_gpu_launcher
    distrib_run.run(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/project/p_trancal/CamLidCalib_Trans/train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-03-31_20:37:40
  host      : cn03.head.komondor.hpc.einfra.hu
  rank      : 4 (local_rank: 0)
  exitcode  : 1 (pid: 806645)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
[2024-03-31 20:37:41,138] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 1 (pid: 805749) of binary: /project/p_trancal/trsclbjob/bin/python
Traceback (most recent call last):
  File "/project/p_trancal/trsclbjob/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1044, in launch_command
    multi_gpu_launcher(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 702, in multi_gpu_launcher
    distrib_run.run(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/project/p_trancal/CamLidCalib_Trans/train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-03-31_20:37:40
  host      : cn04.head.komondor.hpc.einfra.hu
  rank      : 7 (local_rank: 1)
  exitcode  : 1 (pid: 805749)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
srun: error: cn03: task 2: Exited with exit code 1
srun: launch/slurm: _step_signal: Terminating StepId=4178435.0
slurmstepd: error: *** STEP 4178435.0 ON cn01 CANCELLED AT 2024-03-31T20:37:41 ***
[2024-03-31 20:37:41,433] torch.distributed.elastic.agent.server.api: [WARNING] Received Signals.SIGTERM death signal, shutting down workers
[2024-03-31 20:37:41,433] torch.distributed.elastic.agent.server.api: [WARNING] Received Signals.SIGTERM death signal, shutting down workers
[2024-03-31 20:37:41,433] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1211131 closing signal SIGTERM
Traceback (most recent call last):
  File "/project/p_trancal/trsclbjob/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1044, in launch_command
    multi_gpu_launcher(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 702, in multi_gpu_launcher
    distrib_run.run(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 259, in launch_agent
    result = agent.run()
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/metrics/api.py", line 123, in wrapper
    result = f(*args, **kwargs)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 727, in run
    result = self._invoke_run(role)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 868, in _invoke_run
    time.sleep(monitor_interval)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
torch.distributed.elastic.multiprocessing.api.SignalException: Process 807541 got signal: 15
srun: error: cn04: task 3: Exited with exit code 1
srun: error: cn02: task 1: Exited with exit code 1
Traceback (most recent call last):
  File "/project/p_trancal/trsclbjob/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1044, in launch_command
    multi_gpu_launcher(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/accelerate/commands/launch.py", line 702, in multi_gpu_launcher
    distrib_run.run(args)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 259, in launch_agent
    result = agent.run()
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/metrics/api.py", line 123, in wrapper
    result = f(*args, **kwargs)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 727, in run
    result = self._invoke_run(role)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 868, in _invoke_run
    time.sleep(monitor_interval)
  File "/project/p_trancal/trsclbjob/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
torch.distributed.elastic.multiprocessing.api.SignalException: Process 1210732 got signal: 15
srun: error: cn01: task 0: Exited with exit code 1

My SLURM script:
#!/bin/bash
#SBATCH --job-name=Trial
#SBATCH --partition=ai
#SBATCH --time=03:00:00
#SBATCH -N 4
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:4
#SBATCH --mem=150GB

export LAUNCHER="accelerate launch
--config_file CamLidCalib_Trans/config/disGPU_accelerate.yaml
--num_processes 8
--num_machines $SLURM_NNODES
--machine_rank $SLURM_PROCID
--main_process_ip $head_node_ip
--main_process_port $UID
--rdzv_backend c10d
"
export SCRIPT="/project/p_trancal/CamLidCalib_Trans/train.py"
export CMD="$LAUNCHER $SCRIPT"
NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 srun $CMD

Main function:

import functools

import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import DataLoader
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Project-specific imports (get_parser, EncoderBlock, TransformerCalib, PreKittiData) omitted here.

print('Start main')
args = get_parser()
num_gpus = torch.cuda.device_count()

transformer_auto_wrapper_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        EncoderBlock,
    },
)

# Pass advanced FSDP settings that are not part of the accelerate config via fsdp_plugin
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=transformer_auto_wrapper_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision_policy=MixedPrecision(reduce_dtype=torch.float16),
)
# Initialize accelerator
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
print('Check plugin: ', accelerator.state.fsdp_plugin)

device = accelerator.device
model = TransformerCalib(device=device, args=args)
model = accelerator.prepare_model(model)

dataSet = PreKittiData(root_dir=args.data_root, args=args)
valid_loader = DataLoader(dataSet.getData(valid=False), batch_size=args.batch_size, drop_last=True, num_workers=4)

optimizer = torch.optim.Adam(model.parameters(), lr=float(args.learning_rate))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.sche_step_size, gamma=args.sche_gamma)
optimizer, valid_loader, scheduler = accelerator.prepare(optimizer, valid_loader, scheduler)

Output of the FSDP plugin print:
Check plugin: FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.FULL_SHARD: 1>, backward_prefetch=None, mixed_precision_policy=MixedPrecision(param_dtype=None, reduce_dtype=torch.float16, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>,)), auto_wrap_policy=functools.partial(<function transformer_auto_wrap_policy at 0x7f98b544fbe0>, transformer_layer_cls={<class 'Models.Encoder.EncoderBlock'>}), cpu_offload=CPUOffload(offload_params=True), ignored_modules=None, state_dict_type=<StateDictType.FULL_STATE_DICT: 1>, state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), limit_all_gathers=True, use_orig_params=True, param_init_fn=<function FullyShardedDataParallelPlugin.__post_init__.<locals>.<lambda> at 0x7f986527e200>, sync_module_states=True, forward_prefetch=False, activation_checkpointing=False)

The error happens during the first forward pass, so it is most likely not caused by the backward pass or the optimizer setup. I also noticed that no matter how I change the FSDP configuration (different sharding strategies, different mixed-precision settings), it seems to have no effect on the memory footprint: the used memory stays the same and it fails on the same line of code.
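
For reference, here is a minimal sketch of how the per-rank memory footprint could be compared between runs (this is not the exact code from train.py; the tag string is just a placeholder):

import torch

def log_gpu_memory(tag: str) -> None:
    # Allocated vs. reserved memory on the current device, in GiB, so that
    # different FSDP settings can be compared run against run.
    gib = 1024 ** 3
    print(f"[{tag}] allocated={torch.cuda.memory_allocated() / gib:.2f} GiB "
          f"reserved={torch.cuda.memory_reserved() / gib:.2f} GiB "
          f"peak={torch.cuda.max_memory_allocated() / gib:.2f} GiB")

log_gpu_memory('after prepare')  # e.g. right after accelerator.prepare_model(model)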

My questions:
Is my configuration wrong somewhere (SLURM script, accelerate config)?
Should I enable activation checkpointing or other FSDP features?
How can I tell whether my model is actually sharded across the GPUs, and how many GPUs one model/trainer uses? How do I configure that? (See the sketch after these questions for the kind of check I mean.)
Should I switch from the transformer-based auto-wrap policy to a size-based one?
Could I get some advice on configuring FSDP? Which features should be turned on?
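
For the sharding question, this is the kind of check I have in mind, as a minimal sketch; it assumes the model and accelerator objects from the main function above, after accelerator.prepare_model has wrapped the model:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Count the submodules that ended up wrapped by FSDP and the parameter
# elements that live on this rank; under FULL_SHARD the local element count
# should be roughly total_params / num_processes outside forward/backward.
wrapped = [m for m in model.modules() if isinstance(m, FSDP)]
local_numel = sum(p.numel() for p in model.parameters())
print(f"rank {accelerator.process_index}/{accelerator.num_processes}: "
      f"{len(wrapped)} FSDP-wrapped modules, local param elements = {local_numel}")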

PS: My batch size is already 1, and I want to keep the model at its current size. :melting_face: