Running FSDP inside NGC docker with pin_memory gives segmentation fault

Hi, I’m trying to train my model with FSDP and activation checkpointing, following the tutorial here.
My code runs fine on my test server (with two 1080 GPUs), but fails with a segmentation fault on our A100 server running the NGC docker image.

================================
==== backtrace (tid: 116576) ====
 0 0x0000000000043090 killpg()  ???:0
 1 0x00000000001189bc mprotect()  ???:0
 2 0x00000000012739bc at::MapAllocator::close()  ???:0
 3 0x0000000001273c0f at::MapAllocator::~MapAllocator()  ???:0
 4 0x0000000001273c5d at::MapAllocator::~MapAllocator()  ???:0
 5 0x00000000005072fc c10::StorageImpl::release_resources()  :0
 6 0x000000000003b835 c10::intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10::StorageImpl> >::reset_()  :0
 7 0x00000000000350b6 c10::TensorImpl::~TensorImpl()  ???:0
 8 0x00000000000351dd c10::TensorImpl::~TensorImpl()  TensorImpl.cpp:0
 9 0x0000000000786958 THPVariable_clear()  python_variable.cpp:0
10 0x0000000000786ce5 THPVariable_subclass_dealloc()  ???:0
11 0x00000000005ce863 _PyDict_Next()  ???:0
12 0x00000000005d176c _PyDict_GetItemStringWithError()  ???:0
13 0x00000000005ae9ca PySet_Contains()  ???:0
14 0x0000000000570a60 _PyEval_EvalFrameDefault()  ???:0
15 0x0000000000569d8a _PyEval_EvalCodeWithName()  ???:0
16 0x00000000005f60c3 _PyFunction_Vectorcall()  ???:0
17 0x000000000056bab6 _PyEval_EvalFrameDefault()  ???:0
18 0x0000000000569d8a _PyEval_EvalCodeWithName()  ???:0
19 0x00000000005f60c3 _PyFunction_Vectorcall()  ???:0
20 0x00000000005f52b2 PyObject_Call()  ???:0
21 0x000000000056d2bc _PyEval_EvalFrameDefault()  ???:0
22 0x00000000005f5ee6 _PyFunction_Vectorcall()  ???:0
23 0x000000000056bbe1 _PyEval_EvalFrameDefault()  ???:0
24 0x00000000005f5ee6 _PyFunction_Vectorcall()  ???:0
25 0x000000000056bbe1 _PyEval_EvalFrameDefault()  ???:0
26 0x00000000005f5ee6 _PyFunction_Vectorcall()  ???:0
27 0x000000000050b32c PyMethod_New()  ???:0
28 0x00000000005f52b2 PyObject_Call()  ???:0
29 0x00000000006568ec PyInit__thread()  ???:0
30 0x0000000000677e08 _PyFloat_FormatAdvancedWriter()  ???:0
31 0x0000000000008609 start_thread()  ???:0
32 0x000000000011f133 clone()  ???:0
=================================

I tried to gdb into the core dump files, and it looks like the problem is related to the pin_memory flag.

# py-bt gives
Traceback (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/pin_memory.py", line 295, in do_one_step
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/pin_memory.py", line 52, in _pin_memory_loop
    do_one_step()
  File "/usr/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()

So I also tried setting the pin_memory flag to False; the code then works, but is much slower. I’m wondering why setting pin_memory to True causes the above error.
I’m more likely to run into the error with a large batch size and a large number of workers. Setting batch_size to 64 and num_workers=1 also works.
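
For reference, these are the two DataLoader configurations that do not crash for me (using the same MyDataset and DistributedSampler as in the full script below):

# Workaround 1: disable pinning entirely (stable, but noticeably slower)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=512, num_workers=8, sampler=sampler, pin_memory=False)

# Workaround 2: keep pin_memory=True but shrink the batch size and worker count
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, num_workers=1, sampler=sampler, pin_memory=True)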

Environment:
CPU: AMD EPYC 7K62
GPU: 8x A100
Driver: 470.82.01
CUDA: 11.4

Steps to reproduce:

docker run --name ngc --rm -it --gpus all --shm-size=50gb nvcr.io/nvidia/pytorch:22.12-py3
python run.py

My training code (run.py):

import os
import time
import functools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
# from transformers import BertModel, BertConfig, BertLayer
from torchvision.models.vision_transformer import Encoder, EncoderBlock, MLPBlock

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper as checkpoint_wrapper_pytorch,
    CheckpointImpl,
)
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

import random
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, ):
        self.x = -1
        self.test_mode = False
    def __len__(self, ):
        return 1024000
        # return 102400
    def __getitem__(self, idx):
        # img = torch.ones(3, 224, 224)
        # return {"img": img}
        x1 = {
            "img": torch.rand(3, 224, 224),
            # "doc_input_ids": torch.rand(3, 224, 224),
            "doc_input_ids": torch.rand(78),
        }
        return x1

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def train(data, model, rank, world_size, optimizer):
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size)
    # works if pin_memory=False, but slow
    # seg fault if pin_memory=True
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=512, num_workers=8, sampler=sampler, pin_memory=True)
    model.train()

    for ind, data in enumerate(dataloader):
        optimizer.zero_grad()
        data = data['img']
        data = data.to(rank)
        if rank == 0: print(data.shape)
        output = model(data)
        loss = output[0].mean()
        if rank == 0: print(ind, loss)

        loss.backward()
        optimizer.step()


def fsdp_main(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    bs = 64
    model = torchvision.models.vit_b_16(pretrained=False)
    data = torch.randn((bs, 3, 224, 224))
    attn_layers = [EncoderBlock, ]
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls = { *attn_layers }
    )
    mixed_precision_policy = None
    model = FSDP(model.to(rank),
            mixed_precision=mixed_precision_policy,
            device_id=rank,
            auto_wrap_policy=auto_wrap_policy)

    if True:
        check_fn = lambda submodule: \
                any([isinstance(submodule, x) for x in attn_layers])

        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper_pytorch,
            offload_to_cpu=False,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )

        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn = non_reentrant_wrapper,
            check_fn=check_fn,
        )

    optimizer = optim.Adadelta(model.parameters(), lr=1e-2)
    train(data, model, rank, world_size, optimizer)
    dist.destroy_process_group()

if __name__ == '__main__':
    WORLD_SIZE = torch.cuda.device_count()
    mp.spawn(fsdp_main,
        args=(WORLD_SIZE, ),
        nprocs=WORLD_SIZE,
        join=True)

Many thanks in advance!

Are you seeing the same issue using the latest nightly binaries or without FSDP?

@ptrblck
Hi, DDP works fine with both the NGC docker (1.14.0a0+410ce96) and the nightly build.
With FSDP, I observe similar behavior on the nightly build: the network runs two or three steps, then hangs with one GPU at 0% utilization (from nvidia-smi), and then crashes with a segmentation fault. (The DDP variant I compared against is sketched after the logs below.)

# some NGC logs are skipped; I can paste them here as well if needed.
qs-665-6120-vjob-0:14815:20418 [0] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14815:20418 [0] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14815:20418 [0] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14822:20428 [7] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14822:20428 [7] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14822:20428 [7] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14816:20422 [1] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14816:20422 [1] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14816:20422 [1] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14819:20431 [4] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14819:20431 [4] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14819:20431 [4] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14817:20424 [2] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14817:20424 [2] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14817:20424 [2] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14818:20432 [3] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14818:20432 [3] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14818:20432 [3] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14820:20426 [5] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14820:20426 [5] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14820:20426 [5] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14821:20420 [6] NCCL INFO Connected all trees
qs-665-6120-vjob-0:14821:20420 [6] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
qs-665-6120-vjob-0:14821:20420 [6] NCCL INFO 24 coll channels, 32 p2p channels, 32 p2p channels per peer
qs-665-6120-vjob-0:14820:20426 [5] NCCL INFO comm 0x270db2a0 rank 5 nranks 8 cudaDev 5 busId a7000 - Init COMPLETE
qs-665-6120-vjob-0:14819:20431 [4] NCCL INFO comm 0x267dec60 rank 4 nranks 8 cudaDev 4 busId a2000 - Init COMPLETE
qs-665-6120-vjob-0:14818:20432 [3] NCCL INFO comm 0x26bcc2e0 rank 3 nranks 8 cudaDev 3 busId 69000 - Init COMPLETE
qs-665-6120-vjob-0:14821:20420 [6] NCCL INFO comm 0x2798bed0 rank 6 nranks 8 cudaDev 6 busId e0000 - Init COMPLETE
qs-665-6120-vjob-0:14816:20422 [1] NCCL INFO comm 0x27c4ea90 rank 1 nranks 8 cudaDev 1 busId 2b000 - Init COMPLETE
qs-665-6120-vjob-0:14817:20424 [2] NCCL INFO comm 0x27906b30 rank 2 nranks 8 cudaDev 2 busId 64000 - Init COMPLETE
qs-665-6120-vjob-0:14815:20418 [0] NCCL INFO comm 0x2630afd0 rank 0 nranks 8 cudaDev 0 busId 25000 - Init COMPLETE
qs-665-6120-vjob-0:14822:20428 [7] NCCL INFO comm 0x27f46120 rank 7 nranks 8 cudaDev 7 busId e6000 - Init COMPLETE
0 tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([512, 3, 224, 224])
1 tensor(-0.0067, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([512, 3, 224, 224])
2 tensor(-0.0135, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([512, 3, 224, 224])
# here it hangs for some time
==== backtrace (tid:  15872) ====
 0 0x0000000000014420 __funlockfile()  ???:0
 1 0x00000000001189bc mprotect()  ???:0
 2 0x00000000015532c4 at::MapAllocator::close()  ???:0
 3 0x00000000015534cb at::MapAllocator::~MapAllocator()  ???:0
 4 0x0000000001553549 at::MapAllocator::~MapAllocator()  ???:0
 5 0x00000000004cf138 c10::StorageImpl::release_resources()  :0
 6 0x0000000000041c15 c10::intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10::StorageImpl> >::reset_()  :0
 7 0x000000000003816e c10::TensorImpl::~TensorImpl()  ???:0
 8 0x0000000000038289 c10::TensorImpl::~TensorImpl()  TensorImpl.cpp:0
 9 0x0000000000758048 THPVariable_clear()  python_variable.cpp:0
10 0x00000000007583f5 THPVariable_subclass_dealloc()  ???:0
11 0x00000000004c2a0f _Py_Dealloc()  /usr/local/src/conda/python-3.8.16/Objects/object.c:2215
12 0x00000000004c2863 dictkeys_decref()  /usr/local/src/conda/python-3.8.16/Objects/dictobject.c:324
13 0x00000000004c2863 dict_dealloc()  /usr/local/src/conda/python-3.8.16/Objects/dictobject.c:1998
14 0x00000000004d7a5b _Py_Dealloc()  /usr/local/src/conda/python-3.8.16/Objects/object.c:2215
15 0x00000000004d7a5b _Py_DECREF()  /usr/local/src/conda/python-3.8.16/Include/object.h:478
16 0x00000000004d7a5b _Py_XDECREF()  /usr/local/src/conda/python-3.8.16/Include/object.h:541
17 0x00000000004d7a5b tupledealloc()  /usr/local/src/conda/python-3.8.16/Objects/tupleobject.c:247
18 0x00000000004cce1f _Py_Dealloc()  /usr/local/src/conda/python-3.8.16/Objects/object.c:2215
19 0x00000000004cce1f _Py_DECREF()  /usr/local/src/conda/python-3.8.16/Include/object.h:478
20 0x00000000004cce1f _Py_XDECREF()  /usr/local/src/conda/python-3.8.16/Include/object.h:541
21 0x00000000004cce1f _PyEval_EvalFrameDefault()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:1357
22 0x00000000004c6fe5 PyEval_EvalFrameEx()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:741
23 0x00000000004db2ac _PyFunction_Vectorcall()  /usr/local/src/conda/python-3.8.16/Objects/call.c:436
24 0x00000000004c8a47 _PyObject_Vectorcall()  /usr/local/src/conda/python-3.8.16/Include/cpython/abstract.h:127
25 0x00000000004c8a47 _Py_CheckFunctionResult()  /usr/local/src/conda/python-3.8.16/Objects/call.c:25
26 0x00000000004c8a47 _PyObject_Vectorcall()  /usr/local/src/conda/python-3.8.16/Include/cpython/abstract.h:128
27 0x00000000004c8a47 call_function()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:4963
28 0x00000000004c8a47 _PyEval_EvalFrameDefault()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:3500
29 0x00000000004c6fe5 PyEval_EvalFrameEx()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:741
30 0x00000000004db2ac _PyFunction_Vectorcall()  /usr/local/src/conda/python-3.8.16/Objects/call.c:436
31 0x00000000004ed77e PyVectorcall_Call()  /usr/local/src/conda/python-3.8.16/Objects/call.c:200
32 0x00000000004ed77e PyObject_Call()  /usr/local/src/conda/python-3.8.16/Objects/call.c:228
33 0x00000000004ca03a do_call_core()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:5010
34 0x00000000004ca03a _PyEval_EvalFrameDefault()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:3559
35 0x00000000004db216 PyEval_EvalFrameEx()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:741
36 0x00000000004db216 _PyFunction_Vectorcall()  /usr/local/src/conda/python-3.8.16/Objects/call.c:411
37 0x00000000004c8b7d _PyObject_Vectorcall()  /usr/local/src/conda/python-3.8.16/Include/cpython/abstract.h:127
38 0x00000000004c8b7d _Py_CheckFunctionResult()  /usr/local/src/conda/python-3.8.16/Objects/call.c:25
39 0x00000000004c8b7d _PyObject_Vectorcall()  /usr/local/src/conda/python-3.8.16/Include/cpython/abstract.h:128
40 0x00000000004c8b7d call_function()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:4963
41 0x00000000004c8b7d _PyEval_EvalFrameDefault()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:3486
42 0x00000000004db216 PyEval_EvalFrameEx()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:741
43 0x00000000004db216 _PyFunction_Vectorcall()  /usr/local/src/conda/python-3.8.16/Objects/call.c:411
44 0x00000000004c8b7d _PyObject_Vectorcall()  /usr/local/src/conda/python-3.8.16/Include/cpython/abstract.h:127
45 0x00000000004c8b7d _Py_CheckFunctionResult()  /usr/local/src/conda/python-3.8.16/Objects/call.c:25
46 0x00000000004c8b7d _PyObject_Vectorcall()  /usr/local/src/conda/python-3.8.16/Include/cpython/abstract.h:128
47 0x00000000004c8b7d call_function()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:4963
48 0x00000000004c8b7d _PyEval_EvalFrameDefault()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:3486
49 0x00000000004db216 PyEval_EvalFrameEx()  /usr/local/src/conda/python-3.8.16/Python/ceval.c:741
50 0x00000000004db216 _PyFunction_Vectorcall()  /usr/local/src/conda/python-3.8.16/Objects/call.c:411
51 0x00000000004e96c5 _PyObject_Vectorcall()  /usr/local/src/conda/python-3.8.16/Include/cpython/abstract.h:127
52 0x00000000004e96c5 method_vectorcall()  /usr/local/src/conda/python-3.8.16/Objects/classobject.c:67
53 0x00000000004ed77e PyVectorcall_Call()  /usr/local/src/conda/python-3.8.16/Objects/call.c:200
54 0x00000000004ed77e PyObject_Call()  /usr/local/src/conda/python-3.8.16/Objects/call.c:228
55 0x00000000005a0ea7 t_bootstrap()  /usr/local/src/conda/python-3.8.16/Modules/_threadmodule.c:1002
56 0x00000000005a0e04 pythread_wrapper()  /usr/local/src/conda/python-3.8.16/Python/thread_pthread.h:232
57 0x0000000000008609 start_thread()  ???:0
=================================
Traceback (most recent call last):
  File "debug.py", line 143, in <module>
    mp.spawn(fsdp_main,
  File "/root/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/root/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/root/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 140, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 6 terminated with signal SIGSEGV
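
For completeness, the DDP baseline that runs without issues is essentially the same script as run.py above, with the FSDP wrapping and activation checkpointing replaced by a plain DDP wrapper (a sketch, not the exact code I ran):

# same MyDataset, DistributedSampler and DataLoader (pin_memory=True) as above
model = torchvision.models.vit_b_16(pretrained=False).to(rank)
model = DDP(model, device_ids=[rank])
# the training loop in train() stays unchanged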

I assume you’ve installed a nightly binary release and the tutorial was still failing, or did you build PyTorch from source?

@ptrblck
I installed the nightly build via conda:

conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch-nightly -c nvidia

and yes, it is still failing.