How to fix SIGSEGV in distributed training (i.e. DDP)

Hey,

This problem may be closely related to this post; however, no solution is currently available there, so here goes:

Problem: When I run the following training routine, it sometimes finishes cleanly and sometimes terminates with a SIGSEGV error.

Environment:

  • Python 3.9.1
  • torch 1.7.1
  • CUDA 11.0
  • cluster with nodes containing 8 GPUs each (shared by multiple users); jobs are submitted via the LSF batch system

Code:

import os
import torch
import socket
import random
import argparse

from datetime import date
from contextlib import closing
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset

import numpy as np
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp


class my_dataset(Dataset):
    def __init__(self, n):
        self.data = torch.rand((n, 10), dtype=torch.float32)
        self.labels = torch.randint(0, 2, (n, 1))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {'data': self.data[idx, :], 'labels': self.labels[idx]}


def find_free_port():
    """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return str(s.getsockname()[1])


def arg_parse():
    desc = "Program to train a segmentation model."
    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument('--devices',
                        type=str,
                        nargs='+',
                        default=['0'],
                        help='Devices to use for model training. Can be GPU IDs as in default or "cpu".')

    return parser.parse_args()


def train_model_distributed(rank_gpu, world_size, dataset, **kwargs):
    # initialize process group
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank_gpu)
    torch.cuda.set_device(rank_gpu)

    lr = 1e-3
    batch_size = 8
    batch_size_distr = int(batch_size / world_size)

    # set up dataloader
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank_gpu)

    train_loader = DataLoader(dataset, batch_size=batch_size_distr, shuffle=False, num_workers=2, prefetch_factor=2,
                              pin_memory=True, sampler=train_sampler)

    # set up model
    model = nn.Sequential(
            nn.Linear(10, 1, bias=True),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True)
            )
    model.cuda(rank_gpu)

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank_gpu])

    # set up training
    criterion = nn.MSELoss().cuda(rank_gpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
    start_epoch = 0
    train_sampler.set_epoch(start_epoch)
    model.train()

    for batch in train_loader:
        data = batch['data']
        labels = batch['labels']

        data = data.to(device=torch.device(rank_gpu), dtype=torch.float32, non_blocking=True)
        labels = labels.to(device=torch.device(rank_gpu), dtype=torch.float32, non_blocking=True)

        # forward pass
        preds = model(data)

        loss = criterion(preds, labels)
        batch_loss = loss.item()

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'end of training loop rank: {rank_gpu}')
    dist.destroy_process_group()


def main():
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = find_free_port()
    args = arg_parse()

    # devices are submitted as $CUDA_VISIBLE_DEVICES
    devices = args.devices[0].split(',')

    # random seeding
    seed = 7612873
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    dataset = my_dataset(80)
    world_size = len(devices)
    mp.spawn(train_model_distributed, nprocs=world_size, args=(world_size, dataset))
    print('finished ddp training!')


if __name__ == '__main__':
    main()

Edit: An excerpt from the log and the error message:

end of training loop rank: 0
end of training loop rank: 1
Traceback (most recent call last):
  File "/cluster/home/USER/projects/project1/run_debugging.py", line 174, in <module>
    main()
  File "/cluster/home/USER/projects/project1/run_debugging.py", line 169, in main
    mp.spawn(train_model_distributed, nprocs=world_size, args=(world_size, dataset))
  File "/cluster/home/USER/.pyenv/versions/torch17/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/cluster/home/USER/.pyenv/versions/torch17/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/cluster/home/USER/.pyenv/versions/torch17/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 105, in join
    raise Exception(
Exception: process 1 terminated with signal SIGSEGV

I train on two GPUs on a single node. As mentioned above, the code sometimes executes correctly and sometimes terminates with a SIGSEGV. This holds true even when the code is run repeatedly on the same node (both in parallel and sequentially).
I have gone through all the torch.distributed tutorials, many forum posts, GitHub issues, and example implementations of DistributedDataParallel(), but nothing I found there has helped.

I am looking forward to any potential solution, hint, or advice…

Cheers


I wasn’t able to reproduce this issue on my end, unfortunately. Could you ensure that dist.destroy_process_group() is exiting cleanly and that the SIGSEGV does not come from there?

Do your processes always exit cleanly if you remove the distributed init/DDP from your spawned subprocess? This may help narrow down the issue a bit more.
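
For illustration, a stripped-down worker along these lines could be used for that check (a sketch reusing the names and imports from the script above; train_model_plain is a hypothetical helper, not code from the original post):

def train_model_plain(rank_gpu, world_size, dataset, **kwargs):
    # hypothetical stripped-down worker: same imports as the script above (torch, nn, DataLoader);
    # no init_process_group(), no DistributedSampler, no SyncBatchNorm, no DDP wrapper
    torch.cuda.set_device(rank_gpu)

    train_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2,
                              prefetch_factor=2, pin_memory=True)

    model = nn.Sequential(
            nn.Linear(10, 1, bias=True),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True)
            ).cuda(rank_gpu)

    criterion = nn.MSELoss().cuda(rank_gpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.999))
    model.train()

    for batch in train_loader:
        data = batch['data'].to(device=torch.device(rank_gpu), dtype=torch.float32, non_blocking=True)
        labels = batch['labels'].to(device=torch.device(rank_gpu), dtype=torch.float32, non_blocking=True)

        loss = criterion(model(data), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'end of training loop rank: {rank_gpu}')

If the spawned processes still crash with this version, the DDP/process-group teardown can largely be ruled out as the source of the SIGSEGV.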

Thanks for the fast reply @rvarm1, I appreciate it.

  1. I think the SIGSEGV is coming from dist.destroy_process_group(), since print(f'end of training loop rank: {rank_gpu}') is logged but print('finished ddp training!') is not.
  2. I removed dist.init_process_group(), DistributedSampler(), SyncBatchNorm, and DistributedDataParallel() in train_model_distributed(), but kept mp.spawn() (if that is what you meant). My jobs have been queued for a while now; I will get back to you as soon as they complete.

UPDATE: The SIGSEGV still occurs even after removing all the distributed training parts (see 2.).
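
For completeness, a bare mp.spawn() reproducer that touches neither CUDA nor the DataLoader could narrow this down even further (a hypothetical minimal script, not the exact code that was run):

import torch.multiprocessing as mp


def dummy_worker(rank, world_size):
    # does nothing except print, so any crash points at the spawn/teardown machinery itself
    print(f'dummy worker {rank}/{world_size} done')


def main():
    world_size = 2
    mp.spawn(dummy_worker, nprocs=world_size, args=(world_size,))
    print('finished spawn test!')


if __name__ == '__main__':
    main()

If even this does not always exit cleanly, the problem is more likely in the interpreter/multiprocessing layer than in torch.distributed itself.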

Hey @rvarm1,

So I have found a workaround that seems to work:

Running the above script my_script.py with Python’s development mode enabled (-X dev adds extra runtime checks), i.e. via

python -X dev my_script.py --devices $CUDA_VISIBLE_DEVICES

yielded the following error:
Fatal Python error: PyEval_SaveThread: the function must be called with the GIL held, but the GIL is released (the current Python thread state is NULL)

Searching for that error, I found the following bug report. I do not really understand the details of the bug report referenced therein, but the way Python 3.9 handles the GIL seems to cause the SIGSEGV when running mp.spawn().

Workaround: Downgrading to Python 3.8.7 got rid of the SIGSEGV, and my_script.py now runs without any errors.
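
If it helps anyone else, a simple interpreter-version guard at the top of my_script.py could make the requirement explicit (the exact version bound is an assumption based on the workaround above):

import sys

# Python 3.9 triggered the mp.spawn() SIGSEGV in this setup; fail fast with a readable message
# instead of a segfault if the job lands on the wrong interpreter.
if sys.version_info[:2] >= (3, 9):
    raise RuntimeError(f'Python {sys.version.split()[0]} detected; '
                       'please run this script with Python 3.8.x to avoid the mp.spawn() SIGSEGV.')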

If you have further suggestions on how to make the script work with Python 3.9, I would be curious to test them.