Torchrun Error!

Hi!
I am trying to train a language model on multiple GPUs using torchrun, launched as torchrun --standalone --nproc_per_node=gpu multigpu_torchrun.py. The run crashes and I cannot get any output from the model. Here are the code and the error I receive. I would appreciate it if anyone could help me.

import torch  # needed for torch.load/torch.save and the torch.cuda/torch.nn references below
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

from torch.optim import AdamW


from datareader import get_dataset
from modelreader import get_model

def ddp_setup():
    init_process_group(backend="nccl")
    # Bind this process to its local GPU so all CUDA/NCCL ops use the right device.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.gpu_id = int(os.environ["LOCAL_RANK"])
        self.model = model.to(self.gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[self.gpu_id])

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.gpu_id}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, input_ids, attention_mask):
        self.optimizer.zero_grad()
        print(type(self.model))
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=input_ids)
        print('output calculated')
        loss = outputs.loss
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        # NOTE: len() on a dict batch counts its keys, not the batch size;
        # b_sz is only used by the commented-out print below.
        b_sz = len(next(iter(self.train_data)))
        # print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for batch in self.train_data:
            input_ids = batch['input_ids'].to(self.gpu_id)
            attention_mask = batch['attention_mask'].to(self.gpu_id)
            self._run_batch(input_ids, attention_mask)

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)


def load_train_objs(model_type, lr):
    train_set, valid_set = get_dataset(model_type)
    model = get_model(model_type)
    optimizer = AdamW(model.parameters(), lr=lr)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(args):
    ddp_setup()
    dataset, model, optimizer = load_train_objs(args.model_type, args.lr)
    train_data = prepare_dataloader(dataset, args.batch_size)
    trainer = Trainer(model, train_data, optimizer, args.save_every, args.snapshot_path)
    trainer.train(args.num_epoch)
    destroy_process_group()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Distributed training GPTs')
    parser.add_argument('--total_epochs', type=int, default = 1, help='Total epochs to train the model')
    parser.add_argument('--save_every', type=int, default=1, help='How often to save a snapshot')
    parser.add_argument("--lr", type=float, default=5e-4, help="Training learning rate")
    parser.add_argument("--num_epoch", type=int, default=20, help="Training epoch number")
    parser.add_argument('--batch_size', default=2, type=int, help='Input batch size on each device (default: 2)')
    parser.add_argument("--model_type", default='gpt2', type=str, required=False)
    parser.add_argument("--snapshot_path", type=str, default="snapshot.pt", help="path to save snapshot")

    args = parser.parse_args()
    
    main(args)

Here is the error:

../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [61,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [62,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:975: indexSelectLargeIndex: block: [391,0,0], thread: [63,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Exception raised from query at ../aten/src/ATen/cuda/CUDAEvent.h:91 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7fa4f646b20e in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x13c (0x7fa539dc1cec in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #2: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7fa539dc3cc8 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #3: c10d::ProcessGroupNCCL::workCleanupLoop() + 0x221 (0x7fa539dc5251 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #4: <unknown function> + 0xcda93 (0x7fa54b5eea93 in /opt/conda/lib/python3.7/site-packages/torch/lib/../../../../libstdc++.so.6)
frame #5: <unknown function> + 0x7fa3 (0x7fa564d6cfa3 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x3f (0x7fa564b0506f in /lib/x86_64-linux-gnu/libc.so.6)

ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 27615) of binary: /opt/conda/bin/python3
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 10, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/run.py", line 761, in main
    run(args)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/run.py", line 755, in run
    )(*cmd_args)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 247, in launch_agent
    failures=result.failures,
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
======================================================
multigpu_torchrun.py FAILED
------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-05-01_04:30:15
  host      : pt-a100.c.textgen-fine-exp.internal
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 27615)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 27615
======================================================

Based on the error message, an indexing operation is failing. You could rerun the code with blocking launches, as suggested in the error message, or on the CPU to get a better stacktrace pointing to the failing operation.
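
For example, a minimal single-process sketch along these lines (reusing the datareader/modelreader helpers from your script, and assuming they work outside of torchrun) would run one batch on the CPU, where the failing op raises a plain Python error with a usable stacktrace instead of an asynchronous CUDA assert:

from torch.utils.data import DataLoader

from datareader import get_dataset
from modelreader import get_model

train_set, _ = get_dataset("gpt2")
model = get_model("gpt2")  # deliberately kept on the CPU

batch = next(iter(DataLoader(train_set, batch_size=2)))
outputs = model(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["input_ids"],
)
print(outputs.loss)

If the CPU run fails with e.g. "IndexError: index out of range in self", the stacktrace will point directly at the offending layer.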

Thanks for your reply. Could you give me a bit more information and go into more detail, please?
I set CUDA_LAUNCH_BLOCKING=1 and this is the output:

../aten/src/ATen/native/cuda/Indexing.cu:1141: indexSelectLargeIndex: block: [34,0,0], thread: [29,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1141: indexSelectLargeIndex: block: [34,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1141: indexSelectLargeIndex: block: [34,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/jupyter/finetune-5-gpu2/multigpu_torchrun.py:119 in <module>                               │
│                                                                                                  │
│   116 │                                                                                          │
│   117 │   args = parser.parse_args()                                                             │
│   118 │                                                                                          │
│ ❱ 119 │   main(args)                                                                             │
│   120                                                                                            │
│                                                                                                  │
│ /home/jupyter/finetune-5-gpu2/multigpu_torchrun.py:102 in main                                   │
│                                                                                                  │
│    99 │   dataset, model, optimizer = load_train_objs(args.model_type, args.lr)                  │
│   100 │   train_data = prepare_dataloader(dataset, args.batch_size)                              │
│   101 │   trainer = Trainer(model, train_data, optimizer, args.save_every, args.snapshot_path)   │
│ ❱ 102 │   trainer.train(args.num_epoch)                                                          │
│   103 │   destroy_process_group()                                                                │
│   104                                                                                            │
│   105                                                                                            │
│                                                                                                  │
│ /home/jupyter/finetune-5-gpu2/multigpu_torchrun.py:75 in train                                   │
│                                                                                                  │
│    72 │                                                                                          │
│    73 │   def train(self, max_epochs: int):                                                      │
│    74 │   │   for epoch in range(self.epochs_run, max_epochs):                                   │
│ ❱  75 │   │   │   self._run_epoch(epoch)                                                         │
│    76 │   │   │   if self.gpu_id == 0 and epoch % self.save_every == 0:                          │
│    77 │   │   │   │   self._save_snapshot(epoch)                                                 │
│    78                                                                                            │
│                                                                                                  │
│ /home/jupyter/finetune-5-gpu2/multigpu_torchrun.py:63 in _run_epoch                              │
│                                                                                                  │
│    60 │   │   for batch in self.train_data:                                                      │
│    61 │   │   │   input_ids = batch['input_ids'].to(self.gpu_id)                                 │
│    62 │   │   │   attention_mask = batch['attention_mask'].to(self.gpu_id)                       │
│ ❱  63 │   │   │   self._run_batch(input_ids, attention_mask)                                     │
│    64 │                                                                                          │
│    65 │   def _save_snapshot(self, epoch):                                                       │
│    66 │   │   snapshot = {                                                                       │
│                                                                                                  │
│ /home/jupyter/finetune-5-gpu2/multigpu_torchrun.py:51 in _run_batch                              │
│                                                                                                  │
│    48 │                                                                                          │
│    49 │   def _run_batch(self, input_ids, attention_mask):                                       │
│    50 │   │   self.optimizer.zero_grad()                                                         │
│ ❱  51 │   │   outputs = self.model( input_ids, attention_mask = attention_mask, labels=input_i   │
│    52 │   │   loss = outputs.loss                                                                │
│    53 │   │   loss.backward()                                                                    │
│    54 │   │   self.optimizer.step()                                                              │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1194 in _call_impl             │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/parallel/distributed.py:1040 in forward          │
│                                                                                                  │
│   1037 │   │   │   │   # Notify joined ranks whether they should sync in backwards pass or not.  │
│   1038 │   │   │   │   self._check_global_requires_backward_grad_sync(is_joined_rank=False)      │
│   1039 │   │   │                                                                                 │
│ ❱ 1040 │   │   │   output = self._run_ddp_forward(*inputs, **kwargs)                             │
│   1041 │   │   │                                                                                 │
│   1042 │   │   │   # sync params according to location (before/after forward) user               │
│   1043 │   │   │   # specified as part of hook, if hook was specified.                           │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/parallel/distributed.py:1000 in _run_ddp_forward │
│                                                                                                  │
│    997 │   │   │   │   self.use_side_stream_for_tensor_copies                                    │
│    998 │   │   │   )                                                                             │
│    999 │   │   │   with self._inside_ddp_forward():                                              │
│ ❱ 1000 │   │   │   │   return module_to_run(*inputs[0], **kwargs[0])                             │
│   1001 │   │   else:                                                                             │
│   1002 │   │   │   with self._inside_ddp_forward():                                              │
│   1003 │   │   │   │   return module_to_run(*inputs, **kwargs)                                   │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1194 in _call_impl             │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py:1056 in forward │
│                                                                                                  │
│   1053 │   │   │   use_cache=use_cache,                                                          │
│   1054 │   │   │   output_attentions=output_attentions,                                          │
│   1055 │   │   │   output_hidden_states=output_hidden_states,                                    │
│ ❱ 1056 │   │   │   return_dict=return_dict,                                                      │
│   1057 │   │   )                                                                                 │
│   1058 │   │   hidden_states = transformer_outputs[0]                                            │
│   1059                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1194 in _call_impl             │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py:895 in forward  │
│                                                                                                  │
│    892 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                          │
│    893 │   │   │   │   │   encoder_attention_mask=encoder_attention_mask,                        │
│    894 │   │   │   │   │   use_cache=use_cache,                                                  │
│ ❱  895 │   │   │   │   │   output_attentions=output_attentions,                                  │
│    896 │   │   │   │   )                                                                         │
│    897 │   │   │                                                                                 │
│    898 │   │   │   hidden_states = outputs[0]                                                    │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1194 in _call_impl             │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py:394 in forward  │
│                                                                                                  │
│    391 │   │   │   attention_mask=attention_mask,                                                │
│    392 │   │   │   head_mask=head_mask,                                                          │
│    393 │   │   │   use_cache=use_cache,                                                          │
│ ❱  394 │   │   │   output_attentions=output_attentions,                                          │
│    395 │   │   )                                                                                 │
│    396 │   │   attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)            │
│    397 │   │   outputs = attn_outputs[1:]                                                        │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1194 in _call_impl             │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py:310 in forward  │
│                                                                                                  │
│    307 │   │   │   key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2  │
│    308 │   │   │   attention_mask = encoder_attention_mask                                       │
│    309 │   │   else:                                                                             │
│ ❱  310 │   │   │   query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)  │
│    311 │   │                                                                                     │
│    312 │   │   query = self._split_heads(query, self.num_heads, self.head_dim)                   │
│    313 │   │   key = self._split_heads(key, self.num_heads, self.head_dim)                       │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1194 in _call_impl             │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /opt/conda/lib/python3.7/site-packages/transformers/pytorch_utils.py:115 in forward              │
│                                                                                                  │
│   112 │                                                                                          │
│   113 │   def forward(self, x):                                                                  │
│   114 │   │   size_out = x.size()[:-1] + (self.nf,)                                              │
│ ❱ 115 │   │   x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)                    │
│   116 │   │   x = x.view(size_out)                                                               │
│   117 │   │   return x                                                                           │
│   118                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 17651) of binary: /opt/conda/bin/python3
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 10, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/run.py", line 756, in run
    )(*cmd_args)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 248, in launch_agent
    failures=result.failures,
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
multigpu_torchrun.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-05-01_16:26:03
  host      : pytorch-1-13-20230308-211033.c.textgen-fine-exp.internal
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 17651)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

I’m unsure which details are missing, as I have already suggested debugging steps and don’t have any code I could execute and debug myself.
The failing operation is an indexing operation, which might be raised by e.g. an embedding layer, but you would have to narrow down the actually failing operation with blocking launches or by running on the CPU.
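
This particular assert (srcIndex < srcSelectDimSize) usually means an index tensor contains values that are out of bounds for the dimension being indexed, e.g. token ids >= the number of rows of the embedding matrix. As a quick check (a sketch, assuming a Hugging Face GPT-2 model as shown in your stacktrace), you could validate the batch right before the forward pass in _run_batch:

# unwrap DDP to reach the Hugging Face model, then its input embedding layer
emb = self.model.module.get_input_embeddings()
assert input_ids.min() >= 0 and input_ids.max() < emb.num_embeddings, (
    f"out-of-range token id: max={input_ids.max().item()}, "
    f"embedding rows={emb.num_embeddings}"
)

If this assert fires, a common cause is adding tokens to the tokenizer (e.g. a pad token) without resizing the model's embedding matrix afterwards via model.resize_token_embeddings(len(tokenizer)).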