Combining torch.compile and DistributedDataParallel

Hi,

I consistently run into an exception when I try to get DistributedDataParallel working together with torch.compile.

This is how I set up both:

# move the model to this rank's GPU, wrap it in DDP, then compile
self.model.to(self.rank)
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank])
print("Compiling: Reduced overhead")
self.model = torch.compile(self.model, mode="reduce-overhead").to(self.rank)

# distributed sampler so every rank sees its own shard of the data
dataset = TorchVolumeDataset(volumes=volume_data)
sampler_data = torch.utils.data.DistributedSampler(dataset, rank=self.rank)
volume_loader = DataLoader(
    dataset=dataset,
    batch_size=self.batchsize,
    shuffle=False,
    num_workers=self.workers,
    pin_memory=True,
    sampler=sampler_data,
)
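For context, each rank is spawned via mp.spawn and the process group is initialized beforehand, roughly like the sketch below (setup_ddp, the address, and the port are placeholders, not my exact code):

import os
import torch
import torch.distributed as dist

# Rough sketch of the per-rank init that runs before the snippet above;
# MASTER_ADDR/MASTER_PORT values and world_size are placeholders.
def setup_ddp(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)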

model = self.model.to(self.rank)
model.eval()
volume_loader_tqdm = tqdm(volume_loader, desc="Calculate embeddings", leave=False)
embeddings = []
from torch.cuda.amp import autocast

with torch.no_grad():
    for batch, item_index in volume_loader_tqdm:
        subvolume = batch["volume"]
        subvolume = subvolume.to(self.rank)
        with autocast():
            print("Batch to GPU", self.rank)
            # mixed-precision forward pass; result is moved back to CPU
            subvolume_out = model.forward(subvolume).data.cpu()
        embeddings.append(subvolume_out)
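To make this fully distributed I'd eventually have to gather the per-rank results too, presumably something along these lines (just a sketch on my side, not code I'm running yet):

import torch
import torch.distributed as dist

# Sketch: collect every rank's embeddings; each entry in `gathered`
# ends up holding the concatenated embeddings of one rank.
local = torch.cat(embeddings)
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local)
if dist.get_rank() == 0:
    all_embeddings = torch.cat(gathered)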

So while this implementation is not yet complete for distributed inference, it at least runs if I skip the compile step. With the compile step enabled, it crashes shortly after the first batch is transferred to GPU 1:

Calculate embeddings:   0%|          | 0/38392 [00:00<?, ?it/s]Batch to GPU 1
Batch to GPU 0
Calculate embeddings:   0%|          | 1/38392 [00:11<124:11:39, 11.65s/it]Batch to GPU 0
Batch to GPU 0
Calculate embeddings:   0%|          | 3/38392 [00:11<32:52:06,  3.08s/it]Batch to GPU 0
Batch to GPU 0
Calculate embeddings:   0%|          | 5/38392 [00:11<16:25:58,  1.54s/it]Batch to GPU 0
Batch to GPU 0
Calculate embeddings:   0%|          | 7/38392 [00:12<9:51:14,  1.08it/s]Batch to GPU 0
Batch to GPU 0
Calculate embeddings:   0%|          | 9/38392 [00:12<6:28:42,  1.65it/s]Batch to GPU 0
Batch to GPU 0
Calculate embeddings:   0%|          | 11/38392 [00:12<4:31:31,  2.36it/s]Batch to GPU 0
Batch to GPU 0
Calculate embeddings:   0%|          | 13/38392 [00:12<3:19:06,  3.21it/s]Batch to GPU 0
Batch to GPU 0
Traceback (most recent call last):
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/bin/tomotwin_embed.py", line 33, in <module>
    sys.exit(load_entry_point('tomotwin-cryoet', 'console_scripts', 'tomotwin_embed.py')())
  File "/mnt/data/twagner/Projects/TomoTwin/src/tomotwin-github/tomotwin/embed_main.py", line 593, in _main_
    mp.spawn(
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 328, in cudagraphify_impl
    static_outputs = model(list(static_inputs))
  File "/tmp/torchinductor_twagner/fb/cfbws6csr4pmbsivca3hsjk5qtywkmbgvala4e37ljkcinsthzxq.py", line 467, in call
    buf3 = aten.convolution(buf0, buf1, None, (1, 1, 1), (0, 0, 0), (1, 1, 1), False, (0, 0, 0), 1)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_ops.py", line 502, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/mnt/data/twagner/Projects/TomoTwin/src/tomotwin-github/tomotwin/embed_main.py", line 571, in run
    embed_tomogram(tomo, embedor, conf, window_size, mask)
  File "/mnt/data/twagner/Projects/TomoTwin/src/tomotwin-github/tomotwin/embed_main.py", line 505, in embed_tomogram
    embeddings = sliding_window_embedding(tomo=tomo, boxer=boxer, embedor=embedor)
  File "/mnt/data/twagner/Projects/TomoTwin/src/tomotwin-github/tomotwin/embed_main.py", line 407, in sliding_window_embedding
    embeddings = embedor.embed(volume_data=boxes)
  File "/mnt/data/twagner/Projects/TomoTwin/src/tomotwin-github/tomotwin/modules/inference/embedor.py", line 520, in embed
    subvolume_out = model.forward(subvolume).data.cpu()
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/data/twagner/Projects/TomoTwin/src/tomotwin-github/tomotwin/modules/networks/SiameseNet3D.py", line 522, in forward
    def forward(self, inputtensor):
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "<eval_with_key>.54", line 5, in forward
    submod_0 = self.compiled_submod_0(inputtensor);  inputtensor = None
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_dynamo/backends/distributed.py", line 247, in forward
    x = self.submod(*args)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2836, in forward
    return compiled_fn(full_args)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1224, in g
    return f(*args)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1900, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1249, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 248, in run
    return model(new_inputs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 265, in run
    compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 327, in cudagraphify_impl
    with torch.cuda.graph(graph, stream=stream):
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/cuda/graphs.py", line 173, in __exit__
    self.cuda_graph.capture_end()
  File "/opt/user_software/miniconda3_envs/tomotwin_pt2/lib/python3.10/site-packages/torch/cuda/graphs.py", line 79, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Not sure what is going on. Maybe someone has a hint for me? :slight_smile:

Best,
Thorsten

Could you compile the model before wrapping it into DDP? Based on the error, reduce-overhead tries to apply CUDA Graphs and fails since disallowed operations are captured.
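Something like this sketch (using a plain model/rank in place of your attributes):

# compile first, then wrap the compiled module in DDP
model = model.to(rank)
model = torch.compile(model, mode="reduce-overhead")
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])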

Thanks for your response!

I tried it:

self.model.to(self.rank)
print("Compiling! Reduced overhead")
self.model = torch.compile(self.model, mode="reduce-overhead").to(self.rank)
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank])

Same error :frowning:

Isn't it strange that it only happens on GPU 1? GPU 0 continues for a few batches until it crashes because of GPU 1.

Are you using the nightly binaries? If not, could you try them and check if you still run into the same issue?
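You can confirm which build is active with:

import torch

print(torch.__version__)   # nightly builds report e.g. 2.1.0.dev20230719
print(torch.version.cuda)  # CUDA version the binaries were built with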

Hi, not yet! I will give it a try :slight_smile:


Fixed with:

pytorch 2.1.0.dev20230719 py3.10_cuda11.8_cudnn8.7.0_0 pytorch-nightly

Thank you so much for your help! If I could buy you a coffee/beer I would do it :slight_smile:

Any idea when 2.1.0 gets released?

Haha, good to hear it’s working now and happy to help! :slight_smile:
2.1.0 should be released ~Oct 2023.
