SSH disconnects when using FSDP

I am using FSDP from PyTorch distributed.
When I set the batch size to 80 (98% of my GPU memory allocated), my remote SSH session disconnects and the tmux server is killed as well.

But when I reduce the batch size slightly, it works well.

These are my versions:

Linux

PyTorch (Poetry dependencies)
python = "^3.10"
torchvision = "^0.16.0"
torch = {version = "^2.1.0", extras = ["cu121,cp310"]}
matplotlib = "^3.8.1"
torchaudio = "^2.1.0"
tensorboard = "^2.15.1"
psutil = "^5.9.7"
scikit-image = "^0.22.0"
pyyaml = "^6.0.1"
fastapi = {extras = ["standard"], version = "^0.106.0"}
pytz = "^2023.3.post1"
opencv-python = "^4.8.1.78"
uvicorn = "^0.25.0"
joblib = "^1.3.2"
scipy = "^1.11.4"
pandas = "^2.1.4"
seaborn = "^0.13.0"
thop = "^0.1.1.post2209072238"
protobuf = "3.20.0"
tqdm = "^4.66.1"
python-multipart = "^0.0.6"
librosa = "0.7.0"

Driver and CUDA
Driver Version: 535.154.05 CUDA Version: 12.2

Did you check any logs to see why your session drops? For example, are you running out of host RAM (if you are using CPU offloading)?
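One way to answer that is to watch host memory while training runs. Below is a minimal, stdlib-only monitor sketch that assumes a Linux host (it reads `/proc/meminfo`). The function names `parse_meminfo`, `available_fraction`, and `watch_host_ram` and the 5% warning threshold are illustrative choices, not part of PyTorch.

```python
import time
from pathlib import Path


def parse_meminfo(text: str) -> dict[str, int]:
    """Parse /proc/meminfo-style text into {field: first numeric value}."""
    info: dict[str, int] = {}
    for line in text.splitlines():
        if ":" not in line:
            continue
        key, rest = line.split(":", 1)
        parts = rest.split()
        if parts:
            # memory fields are reported in kB (KiB); some fields have no unit
            info[key.strip()] = int(parts[0])
    return info


def available_fraction(info: dict[str, int]) -> float:
    """Fraction of host RAM the kernel estimates is still available."""
    return info["MemAvailable"] / info["MemTotal"]


def watch_host_ram(interval_s: float = 1.0, warn_below: float = 0.05, steps: int = 60) -> None:
    """Print available host RAM every interval_s seconds, flagging low readings."""
    for _ in range(steps):
        info = parse_meminfo(Path("/proc/meminfo").read_text())
        frac = available_fraction(info)
        avail_gib = info["MemAvailable"] / 1024**2  # KiB -> GiB
        msg = f"MemAvailable: {avail_gib:.1f} GiB ({frac:.1%})"
        if frac < warn_below:
            msg += "  <-- low host RAM"
        print(msg, flush=True)  # flush so redirected output survives a kill
        time.sleep(interval_s)
```

For example, call `watch_host_ram(steps=600)` from a separate process and redirect its output to a file, so the readings survive if the session is killed. If available memory falls toward zero just before the disconnect, the kernel OOM killer is a likely suspect; its messages appear in the kernel log (`dmesg` or `journalctl -k`).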

Yes, I am using CPU offloading like this:

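import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# `model`, `data_type` (a MixedPrecision config) and `rank` (local GPU index) are defined elsewhere.
# min_num_params=1e-2 makes nearly every module that owns parameters its own FSDP unit.
model = FSDP(model, auto_wrap_policy=functools.partial(
    size_based_auto_wrap_policy, min_num_params=1e-2),
             cpu_offload=CPUOffload(offload_params=True),  # sharded params and grads kept in host RAM
             mixed_precision=data_type,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
             device_id=rank)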
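According to PyTorch's `CPUOffload` documentation, `offload_params=True` also offloads gradients to CPU and runs the optimizer step on CPU, so sharded parameters, gradients, and optimizer state all live in host RAM. The back-of-the-envelope estimate below is a sketch under stated assumptions: fp32 sharded parameters and gradients, an Adam-style optimizer with two fp32 states per parameter, and all ranks on one machine. It ignores DataLoader workers, pinned buffers, CUDA context overhead, and allocator fragmentation, so treat the numbers as lower bounds. `offload_host_ram_gib` is a hypothetical helper.

```python
def offload_host_ram_gib(
    num_params: int,
    world_size: int,
    param_bytes: int = 4,  # assumption: fp32 sharded params kept on CPU
    grad_bytes: int = 4,   # assumption: fp32 gradients offloaded to CPU
    optim_bytes: int = 8,  # assumption: Adam-style exp_avg + exp_avg_sq in fp32
) -> dict[str, float]:
    """Rough host-RAM lower bounds (GiB) for CPU-offloaded FSDP on one node."""
    gib = 1024**3
    per_param = param_bytes + grad_bytes + optim_bytes
    shard = num_params / world_size
    return {
        # steady state: each rank keeps only its own shard on CPU
        "per_rank_steady": shard * per_param / gib,
        # summed over all ranks on the host: one full copy of params, grads, optim state
        "node_steady": num_params * per_param / gib,
        # if every rank builds the full model on CPU before wrapping,
        # the host briefly holds world_size full copies of the parameters
        "node_init_peak": world_size * num_params * param_bytes / gib,
    }
```

`num_params` can be taken from `sum(p.numel() for p in model.parameters())` before wrapping. The `node_init_peak` entry only applies if each process materializes the full model on CPU before `FSDP(...)` shards it.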

I don't know why, but it works well now.