Torch.compile + DDP: SIGSEGV/SIGTERM during inference step

Hi,

I am writing a training harness from scratch for work that involves iterative pruning – which uses DDP to train each level.

tl;dr
SIGTERM/SIGSEGV while running inference during a DDP run + model which has been torch.compile’d.

Exact error:

W0526 00:24:17.229000 22419848091456 torch/multiprocessing/spawn.py:145] Terminating process PROCESS_ID via signal SIGTERM
Traceback (most recent call last):
  File "/harness.py", line 346, in <module>
    mp.spawn(main, args=(init_model, world_size, threshold, i), nprocs=world_size, join=True)
  File "/orch/multiprocessing/spawn.py", line 281, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "torch/multiprocessing/spawn.py", line 237, in start_processes
    while not context.join():
  File "torch/multiprocessing/spawn.py", line 169, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGSEGV

Details:

To describe my setup:

  1. Model: ResNet-18/50 (each layer has a mask, the same size as its weight, registered as a buffer).
  2. Batch size: 256 (per GPU)
  3. Autocast: BFloat16
  4. Dataset: ImageNet (224x224)
  5. FFCV for dataloading (but for all purposes this shouldn’t matter)
  6. torch.compile (mode='reduce-overhead') applied before passing the model to DDP, as below:

self.model = torch.compile(self.model, mode='reduce-overhead')
self.model = DDP(self.model, device_ids=[self.gpu_id])

  7. 2x A100 GPUs (80 GB) with abundant memory/CPU resources.
  8. num workers (if relevant): 8

Tried with:

  1. Different PyTorch versions: 2.0.1, 2.1, 2.2.1, 2.2.2, 2.3
  2. CUDA: 11.8 and 12.1

I launch training as follows:

for i, threshold in enumerate(thresholds):
    mp.spawn(main, args=(init_model, world_size, threshold, i), nprocs=world_size, join=True)

I can make no sense of this issue:
The crash point is random within a run, but the crash itself is reproducible across training attempts: one worker hits a SIGSEGV and the remaining processes are terminated with SIGTERM after, say, 33/97 or 10/97 evaluation batches.

I am at a loss as to why this is happening. Plenty of GPU VRAM, host memory etc. is available, and I get no further information: the process simply exits.
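Since the crash gives no information beyond the signal name, one low-cost thing to try is Python's standard-library `faulthandler`, which dumps the Python-level traceback of every thread when the process receives a fatal signal such as SIGSEGV. Below is a minimal sketch; the helper name `enable_crash_traces` and the per-rank file naming are assumptions for illustration, not part of the harness above.

```python
import faulthandler
import os


def enable_crash_traces(rank, log_dir="."):
    """Dump Python tracebacks to a per-rank file on fatal signals (SIGSEGV etc.).

    Hypothetical helper: the name and the log file layout are illustrative.
    """
    os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, f"crash_rank{rank}.log")
    # The file must stay open for the lifetime of the process,
    # because faulthandler writes to its file descriptor at crash time.
    log_file = open(path, "w")
    faulthandler.enable(file=log_file, all_threads=True)
    return log_file
```

Calling this first thing inside the spawned `main(rank, ...)` would give each worker its own log. It only shows the Python frames that were active at the crash, not the native stack, but that is often enough to tell whether the fault happens inside the compiled forward pass, the dataloader, or somewhere else.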

EDIT
Additional context: when I run only 1-2 batches of training, break, and then run the test loop, it completes as expected. The exception occurs only when one full epoch of training is run and I subsequently run inference.

I have done the following:

  1. Experimented with the aforementioned pytorch versions/CUDA versions.
  2. Machines: GCP and a local cluster (same hardware setup)

My test function is very simple:

def test(self):
    # Switch to eval mode (affects BatchNorm/Dropout behaviour)
    self.model.eval()
    test_loss = 0
    correct = 0
    total = 0

    tloader = tqdm.tqdm(self.test_loader, desc='Testing')
    with torch.no_grad():
        for inputs, targets in tloader:
            with autocast(dtype=torch.bfloat16):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

            # .item() keeps only Python numbers, so no tensors are retained across batches
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Mean loss per batch and top-1 accuracy in percent
    test_loss /= len(self.test_loader)
    accuracy = 100. * correct / total

    return test_loss, accuracy

When I run this without torch.compile, everything runs fine.

Hi, I’d like to revisit this, do you have any thoughts @ptrblck?

Training ResNet50 on ImageNet, batch size per gpu 128 on 4xA100 (using FFCV as the dataloader).

A few changes since the original post:

  • changed from mp.spawn to torchrun
  • updated the test loop

(1) torch.compile(model, 'reduce-overhead') + DDP + torch 2.5 → memory usage slowly grows until it throws an OOM and exits
(2) torch.compile(model, 'reduce-overhead'/'max-autotune') + DDP on other versions (even 2.4 etc.) → SIGTERM at inference:
File "/opt/conda/bin/torchrun", line 33, in
sys.exit(load_entry_point('torch==2.2.1', 'console_scripts', 'torchrun')())
But the installed version is not 2.2.1…
(3) DDP (no torch.compile) → runs fine

Additional information:

  • I add a buffer at every layer to act as a mask during the forward pass (before compiling, of course). To make sure updates to the mask are picked up, I set torch._dynamo.config.guard_nn_modules=True.
  • Tried compiling before DDP and after, to the same effect.

After a few batches on the test set, it just dies…

Really not sure what to do here.

I don’t know what might be causing the issues and also don’t fully understand your use case as it seems you are trying to launch multiple DDP jobs at once inside this loop?

I would recommend checking why this old PyTorch release is used, e.g. by checking torch.__path__ to make sure the right binary is used.
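A quick way to run this check is to compare where the package would be imported from, which version its installed metadata reports, and where the launcher script lives. The sketch below uses only the standard library; the helper name `describe_install` is hypothetical. A `torch==2.2.1` string inside the `torchrun` script would be consistent with a launcher left over from an older install, though that is an assumption to verify rather than a diagnosis.

```python
import importlib.metadata
import importlib.util
import shutil
import sys


def describe_install(package="torch", launcher="torchrun"):
    """Report which interpreter, package files and launcher script are in use.

    Hypothetical helper for illustration; it does not import the package.
    """
    spec = importlib.util.find_spec(package)
    try:
        version = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        version = None
    return {
        "python": sys.executable,                        # interpreter running this check
        "module_path": spec.origin if spec else None,    # file the import would load
        "dist_version": version,                         # version from installed metadata
        "launcher_path": shutil.which(launcher),         # launcher script found on PATH
    }


if __name__ == "__main__":
    for key, value in describe_install().items():
        print(f"{key}: {value}")
```

If `launcher_path` and `module_path` point into different environments, or `dist_version` differs from what `torch.__version__` prints inside the job, the run is probably picking up a mixed installation.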

The slowly growing memory usage is often a user error where e.g. the loss is accumulated or appended to a list, which prevents PyTorch from freeing intermediate activations and eventually causes an OOM.

Hi!

(1) Regarding use-case
My use case is iterative magnitude pruning.

  • I initialize a model where each layer has a buffer for a mask, initially all set to 1.
  • The mask is applied in the forward pass. At level 0 this has no impact (other than some overhead, presumably).
  • I train the dense model using DDP (just one process)
  • Once training is complete, I prune a % of the weights (set them to 0 in the corresponding buffer).
  • Release the distributed training group.
  • Initialize and start the next level of training.
  • Repeat until the sparsity target is reached.

The challenge initially was that torch.compile wasn’t detecting the change in the buffer, which is why I had to add torch._dynamo.config.guard_nn_modules=True at the start.
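The masking and pruning steps above can be sketched roughly as follows. This is an illustration under assumptions, not the actual harness: `MaskedLinear` and `prune_by_magnitude` are hypothetical names, a linear layer stands in for the ResNet convolutions, and pruning is done per layer on the weights that are still active.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Linear):
    """Linear layer whose weight is multiplied by a 0/1 mask buffer (hypothetical)."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # A buffer is saved in the state_dict and moved by .to(), but never optimized.
        self.register_buffer("mask", torch.ones_like(self.weight))

    def forward(self, x):
        # With an all-ones mask (level 0) this equals a plain linear layer.
        return F.linear(x, self.weight * self.mask, self.bias)


@torch.no_grad()
def prune_by_magnitude(layer, fraction):
    """Zero the mask for the smallest-magnitude `fraction` of still-active weights."""
    active = layer.mask.bool()
    n_prune = int(fraction * active.sum().item())
    if n_prune == 0:
        return
    # Already-pruned positions get an infinite score so they are never picked again.
    scores = layer.weight.abs().masked_fill(~active, float("inf")).flatten()
    idx = torch.topk(scores, n_prune, largest=False).indices
    # In-place write into the existing buffer rather than replacing it.
    layer.mask.view(-1)[idx] = 0.0
```

For example, `prune_by_magnitude(layer, 0.2)` after each level removes 20% of the remaining weights in that layer, so sparsity compounds across levels.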

So at any given point, I just spawn a single instance of DDP.

(2) Version

  • I tried it 2-3 times, but I will try it with a different docker container which starts from scratch with 2.4/2.5 and will check – thanks!

(3) Regarding the OOM
Will look into it. In the train loop as well as the test loop, I make sure to accumulate only the .item() of the loss. But this could very well be the reason. My confusion also stemmed from the fact that this doesn’t happen on earlier versions of torch, or when torch.compile is not used.

(3) could be user-error, I’ll make changes to see what I’m doing wrong ^^

I’ll keep this updated, thanks for the suggestions!

Okay, I got it working with PyTorch 2.4.

The problem was potentially how I was keeping track of the loss values in a tensor that I called all_reduce on. This was potentially causing the blowup and subsequent OOM, leading to a SIGTERM at inference. This has now been dealt with: I delete the tensor after every epoch.
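For reference, one pattern that avoids retaining anything across batches is to keep a small detached running-sum tensor and all_reduce it once per epoch. This is a rough sketch of that idea, not the exact code from my harness: the class name `EpochLossMeter` is made up for illustration, and the all_reduce is skipped when no process group is initialized, so the same code also runs in a single process.

```python
import torch
import torch.distributed as dist


class EpochLossMeter:
    """Running [sum of losses, number of batches] for one epoch (hypothetical helper)."""

    def __init__(self, device="cpu"):
        self.device = device
        self.reset()

    def reset(self):
        # Call once per epoch so the tensor never grows or carries old state.
        self.stats = torch.zeros(2, dtype=torch.float64, device=self.device)

    def update(self, loss):
        # detach() drops the autograd graph, so activations can be freed.
        self.stats[0] += loss.detach().to(self.stats.dtype)
        self.stats[1] += 1

    def global_mean(self):
        stats = self.stats.clone()
        # Sum over ranks only when running under an initialized process group.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        return (stats[0] / stats[1].clamp(min=1)).item()
```

Usage would be `meter.update(loss)` per batch, then `meter.global_mean()` followed by `meter.reset()` at the end of each epoch.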

Thanks for pointing that out.

However, the problem of memory use growing with each processed batch seems to persist with PT 2.5, though I will run a few ablations to check whether that is always the case.

In any case – appreciate the support!
