Hi, I’m trying to train my model with FSDP and activation checkpointing, following the tutorial here.
My code runs fine on my test server (with 2× GTX 1080), but fails with a segmentation fault on our A100 server inside the NGC docker container.
================================
==== backtrace (tid: 116576) ====
0 0x0000000000043090 killpg() ???:0
1 0x00000000001189bc mprotect() ???:0
2 0x00000000012739bc at::MapAllocator::close() ???:0
3 0x0000000001273c0f at::MapAllocator::~MapAllocator() ???:0
4 0x0000000001273c5d at::MapAllocator::~MapAllocator() ???:0
5 0x00000000005072fc c10::StorageImpl::release_resources() :0
6 0x000000000003b835 c10::intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10::StorageImpl> >::reset_() :0
7 0x00000000000350b6 c10::TensorImpl::~TensorImpl() ???:0
8 0x00000000000351dd c10::TensorImpl::~TensorImpl() TensorImpl.cpp:0
9 0x0000000000786958 THPVariable_clear() python_variable.cpp:0
10 0x0000000000786ce5 THPVariable_subclass_dealloc() ???:0
11 0x00000000005ce863 _PyDict_Next() ???:0
12 0x00000000005d176c _PyDict_GetItemStringWithError() ???:0
13 0x00000000005ae9ca PySet_Contains() ???:0
14 0x0000000000570a60 _PyEval_EvalFrameDefault() ???:0
15 0x0000000000569d8a _PyEval_EvalCodeWithName() ???:0
16 0x00000000005f60c3 _PyFunction_Vectorcall() ???:0
17 0x000000000056bab6 _PyEval_EvalFrameDefault() ???:0
18 0x0000000000569d8a _PyEval_EvalCodeWithName() ???:0
19 0x00000000005f60c3 _PyFunction_Vectorcall() ???:0
20 0x00000000005f52b2 PyObject_Call() ???:0
21 0x000000000056d2bc _PyEval_EvalFrameDefault() ???:0
22 0x00000000005f5ee6 _PyFunction_Vectorcall() ???:0
23 0x000000000056bbe1 _PyEval_EvalFrameDefault() ???:0
24 0x00000000005f5ee6 _PyFunction_Vectorcall() ???:0
25 0x000000000056bbe1 _PyEval_EvalFrameDefault() ???:0
26 0x00000000005f5ee6 _PyFunction_Vectorcall() ???:0
27 0x000000000050b32c PyMethod_New() ???:0
28 0x00000000005f52b2 PyObject_Call() ???:0
29 0x00000000006568ec PyInit__thread() ???:0
30 0x0000000000677e08 _PyFloat_FormatAdvancedWriter() ???:0
31 0x0000000000008609 start_thread() ???:0
32 0x000000000011f133 clone() ???:0
=================================
I tried to gdb into the core dump files, and it looks like the problem is related to the pin_memory flag.
# py-bt gives
Traceback (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/pin_memory.py", line 295, in do_one_step
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/pin_memory.py", line 52, in _pin_memory_loop
    do_one_step()
  File "/usr/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
So I also tried setting the pin_memory flag to False; the code then works, but is much slower. I’m wondering why setting pin_memory to True triggers the error above.
I’m more likely to hit the error with a large batch size and many workers; setting batch_size=64 and num_workers=1 also works.
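For reference, here is a minimal sketch of the DataLoader combinations I observed on the A100 box (dataset and sampler are the same as in run.py below):

# crashes with a segfault in the pin_memory thread
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=512, num_workers=8, sampler=sampler, pin_memory=True)

# runs, but noticeably slower
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=512, num_workers=8, sampler=sampler, pin_memory=False)

# also runs
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, num_workers=1, sampler=sampler, pin_memory=True)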
Environment:
CPU: AMD EPYC 7K62
GPU: 8*A100
Driver: 470.82.01
CUDA: 11.4
Steps to reproduce:
docker run --name ngc --rm -it --gpus all --shm-size=50gb nvcr.io/nvidia/pytorch:22.12-py3
python run.py
My training code (run.py):
import os
import time
import functools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
# from transformers import BertModel, BertConfig, BertLayer
from torchvision.models.vision_transformer import Encoder, EncoderBlock, MLPBlock
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper as checkpoint_wrapper_pytorch,
    CheckpointImpl,
)
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random

random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)


class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = -1
        self.test_mode = False

    def __len__(self):
        return 1024000
        # return 102400

    def __getitem__(self, idx):
        # img = torch.ones(3, 224, 224)
        # return {"img": img}
        x1 = {
            "img": torch.rand(3, 224, 224),
            # "doc_input_ids": torch.rand(3, 224, 224),
            "doc_input_ids": torch.rand(78),
        }
        return x1


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def train(data, model, rank, world_size, optimizer):
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size)
    # works if pin_memory=False, but slow
    # seg fault if pin_memory=True
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=512, num_workers=8, sampler=sampler, pin_memory=True)
    model.train()
    for ind, data in enumerate(dataloader):
        optimizer.zero_grad()
        data = data['img']
        data = data.to(rank)
        if rank == 0: print(data.shape)
        output = model(data)
        loss = output[0].mean()
        if rank == 0: print(ind, loss)
        loss.backward()
        optimizer.step()


def fsdp_main(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    bs = 64
    model = torchvision.models.vit_b_16(pretrained=False)
    data = torch.randn((bs, 3, 224, 224))
    attn_layers = [EncoderBlock, ]
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={*attn_layers},
    )
    mixed_precision_policy = None
    model = FSDP(model.to(rank),
                 mixed_precision=mixed_precision_policy,
                 device_id=rank,
                 auto_wrap_policy=auto_wrap_policy)
    if True:
        check_fn = lambda submodule: \
            any([isinstance(submodule, x) for x in attn_layers])
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper_pytorch,
            offload_to_cpu=False,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=check_fn,
        )
    optimizer = optim.Adadelta(model.parameters(), lr=1e-2)
    train(data, model, rank, world_size, optimizer)
    dist.destroy_process_group()


if __name__ == '__main__':
    WORLD_SIZE = torch.cuda.device_count()
    mp.spawn(fsdp_main,
             args=(WORLD_SIZE, ),
             nprocs=WORLD_SIZE,
             join=True)
Many thanks in advance!