DataParallel throws error: "Caught RuntimeError in replica 1 on device 5" and CUDNN_STATUS_EXECUTION_FAILED

I am currently trying to use PyTorch's nn.DataParallel to increase the batch size, as shown in the code below, but I get 1) RuntimeError: Caught RuntimeError in replica 1 on device 5. and 2) RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED.

**Has anyone worked with DataParallel before and run into similar issues?**

Code:

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device_ids = [1, 5]
model = YoloV4_EfficentNet(nclasses = arguments['nclasses']).to(device)
model = nn.DataParallel(model, device_ids = device_ids)

def trainyolov4(arguments, train_loader, model, optimizer, scheduler, loss_f, scaled_anchors, scaler, mode = 'iou'):
    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.permute(0, 3, 1, 2) #.to(device)
        y0, y1, y2 = (y[0], y[1], y[2]) #.to(device)
        # x shape :-: (batchsize, channels, height, width)
        with autocast():
            preds = model(x)
            loss_val = (
                loss_f(preds[0], y0, scaled_anchors[0], mode = mode)
                + loss_f(preds[1], y1, scaled_anchors[1], mode = mode)
                + loss_f(preds[2], y2, scaled_anchors[2], mode = mode))
        class_acc, noobj_acc, obj_acc = class_accuracy(preds, y, arguments["conf_thresh"])
        optimizer.zero_grad()
        scaler.scale(loss_val).backward()
        scaler.step(optimizer)
        scaler.update()
        if arguments["one_cycle"] == True:
            scheduler.step()
    return (float(loss_val.item()), float(class_acc), float(noobj_acc), float(obj_acc))

Error:

  1. RuntimeError: Caught RuntimeError in replica 1 on device 5.
  2. RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED.
Traceback (most recent call last):
  File "main.py", line 54, in <module>
    train_model_with_args()
  File "main.py", line 52, in train_model_with_args
    initialize_with_args(arguments)
  File "/home/thesis/train/train.py", line 214, in initialize_with_args
    main(arguments)
  File "/home/thesis/train/train.py", line 695, in main
    train_loss_val, train_class_acc, train_noobj_acc, train_obj_acc = trainyolov4(arguments, train_loader, model, optimizer, scheduler, loss_f, scaled_anchors, scaler, mode = 'ciou')
  File "/home/thesis/train/train.py", line 376, in trainyolov4
    preds = model(x)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/home/.local/lib/python3.8/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 5.
Original Traceback (most recent call last):
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/thesis/models/yolov4.py", line 206, in forward
    sclaed_pred2 = self.yolov4head[1](panet_scale2)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/thesis/models/yolov4.py", line 64, in forward
    out = self.scaled_pred(x)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/thesis/models/yolov4.py", line 19, in forward
    out = self.bn(out)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/home/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

Tried:

  1. Explicitly setting CUDA_VISIBLE_DEVICES (see the sketch after this list).
  2. Setting CUDA_LAUNCH_BLOCKING=1 for debugging.
  3. Moving x and y to the device vs. not doing so, since DataParallel should handle this for us.
  4. Making sure the CUDA version 11.6 on the lab servers matches the PyTorch build.
  5. Checking that the GPUs are recognized.
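
For reference, this is roughly how items 1 and 2 were set (a sketch only; the GPU indices are just the ones used in this thread, and both variables have to be set before CUDA is initialized):

import os

# Mask the physical GPUs before torch initializes CUDA, so only GPUs 1 and 5 are
# visible; inside the process they are then addressed as cuda:0 and cuda:1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,5"
# Make kernel launches synchronous so CUDA errors surface at the failing call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch

print(torch.cuda.device_count())  # should report 2 with the mask above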

So I gave torch a fresh install, and it now works, but only if I disable the cuDNN backend with torch.backends.cudnn.enabled = False.
If I use CUDA_LAUNCH_BLOCKING=1, the process freezes and has to be killed manually with kill -9 PID. But if I set the flag back to True (torch.backends.cudnn.enabled = True), I get the following error:

Error:

Traceback (most recent call last):
  File "main.py", line 54, in <module>
    train_model_with_args()
  File "main.py", line 52, in train_model_with_args
    initialize_with_args(arguments)
  File "/home/thesis/train/train.py", line 214, in initialize_with_args
    main(arguments)
  File "/home/thesis/train/train.py", line 699, in main
    train_loss_val, train_class_acc, train_noobj_acc, train_obj_acc = trainyolov4(arguments, train_loader, model, optimizer, scheduler, loss_f, scaled_anchors, scaler, mode = 'ciou')
  File "/home/thesis/train/train.py", line 378, in trainyolov4
    preds = model(x)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/.local/lib/python3.8/site-packages/torch/_utils.py", line 705, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/thesis/models/yolov4.py", line 206, in forward
    sclaed_pred2 = self.yolov4head[1](panet_scale2)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/thesis/models/yolov4.py", line 64, in forward
    out = self.scaled_pred(x)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/thesis/models/yolov4.py", line 21, in forward
    out = self.activation(out)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 790, in forward
    return F.leaky_relu(input, self.negative_slope, self.inplace)
  File "/home/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 1675, in leaky_relu
    result = torch._C._nn.leaky_relu(input, negative_slope)
RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@ptrblck any ideas?

If I disable the backend, am I essentially using torch's own implementations? Based on GPU memory usage, utilization, and training progress, the model does train with the backend flag set to False. Does this affect DataParallel, or does DataParallel still take effect?
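
(For context, this is the toggle in question; just a minimal sketch, nothing model-specific:)

import torch

# With enabled = False, convolution and batch-norm calls fall back to PyTorch's
# native CUDA kernels instead of dispatching to cuDNN.
torch.backends.cudnn.enabled = False

# The cuDNN build itself can still be inspected independently of the flag.
print(torch.backends.cudnn.is_available(), torch.backends.cudnn.version())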

Are you able to reproduce the issue without DataParallel or only with it?
Also, are you using the latest stable or nightly release? If so, could you post a minimal and executable code snippet reproducing the issue?
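
Something along these lines would already help (a skeleton only; the model class, nclasses, and input shape are taken from your snippets, so adjust as needed):

import torch
from torch.cuda.amp import autocast

# Single-GPU forward pass without DataParallel, to check whether the cuDNN
# failure is reproducible on its own.
device = torch.device("cuda:1")
model = YoloV4_EfficentNet(nclasses=80).to(device)
x = torch.randn(2, 3, 608, 608, device=device)

with autocast():
    preds = model(x)
print([p.shape for p in preds])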

Hey @ptrblck,

I abandoned the DP approach in favor of a DDP approach, which seems to work. (See code below.)
However, I was wondering whether it is common to use the DistributedSampler together with DDP?
I am experiencing serious memory overhead when using the DistributedSampler together with DDP. Before, I could fit a batch size of 30 on a single GPU without DDP, and now I can only fit a batch size of 16 on two GPUs.
If I try to fit 30 on 2, 3, or 4 GPUs, I get a CUDA out of memory error.

I am using the newest torch version and have 8 RTX 3080s available.
Should I drop the DistributedSampler and only use DDP?

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def main(rank, world_size):
    image_size = 608
    batch_size = 16
    setup(rank, world_size)
    # yolo anchors rescaled between 0,1
    # yolo scales and anchors for image size 608, 608
    S = [19, 38, 76]
    anchors = [
        [(0.23, 0.18), (0.32, 0.40), (0.75, 0.66)],
        [(0.06, 0.12), (0.12, 0.09), (0.12, 0.24)],
        [(0.02, 0.03), (0.03, 0.06), (0.07, 0.05)],
    ]

    train_dataset = CoCoDataset(
        "data/coco/train.csv",
        "data/coco/images/",
        "data/coco/labels/",
        S=S,
        anchors=anchors,
        image_size=image_size,
        mode="train",
    )

    test_dataset = CoCoDataset(
        "data/coco/test.csv",
        "data/coco/images/",
        "data/coco/labels/",
        S=S,
        anchors=anchors,
        image_size=image_size,
        mode="test",
    )

    scaled_anchors = (torch.tensor(anchors) * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)).to(rank)
    loss_f = YoloV4Loss2()

    sampler_train = DistributedSampler(train_dataset, shuffle=True)
    sampler_test = DistributedSampler(test_dataset, shuffle=False)

    # note: drop_last is False below, so the last (smaller) batch is kept
    train_loader = DataLoader(
        dataset=train_dataset,
        num_workers=2,
        batch_size=batch_size,
        drop_last=False,
        sampler=sampler_train,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        num_workers=2,
        batch_size=batch_size,
        drop_last=False,
        sampler=sampler_test,
    )
    model = YoloV4_EfficentNet(nclasses=80).to(rank)

    ddp_model = DDP(model, device_ids=[rank], gradient_as_bucket_view = True)
    optimizer = optim.Adam(
        ddp_model.parameters(), lr=1e-4, weight_decay=0.0005, betas=(0.937, 0.999)
    )
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-4,
        epochs=300,
        steps_per_epoch=len(train_loader),
        anneal_strategy="cos",
    )

    for epoch in range(300):
        # train
        print(f"Rank:{rank}")
        scaler = GradScaler()
        train_loss_val, train_class_acc, train_noobj_acc, train_obj_acc = trainyolov4(
            rank,
            train_loader,
            ddp_model,  # use the DDP-wrapped model so gradients are synchronized across ranks
            optimizer,
            scheduler,
            loss_f,
            scaled_anchors,
            scaler,
            conf_thresh=0.8,
            mode="ciou",
        )

        print(f"Epoch:{epoch + 1}  Train[Loss:{train_loss_val} Class Acc:{train_class_acc} NoObj acc:{train_noobj_acc} Obj Acc:{train_obj>

if __name__ == "__main__":
    mp.spawn(
        main,
        args=(2,),
        nprocs=2,
        join=True,
    )

Using the DistributedSampler is the common approach as it splits the dataset into shards for each rank making sure samples are not repeated on all ranks (with the exception of repeating samples to make the number of samples divisible by the batch size and ranks).
Note that DP will shard the input batch in dim0 from the default rank, so are you sure every GPU used a batch size of 30, or was it sharded? In the former case your effective batch size would be 30 * num_gpus.
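
For completeness, the usual pattern for driving the sampler looks roughly like this (the set_epoch call is what reshuffles across epochs and is not in your snippet above; the other names are reused from it):

from torch.utils.data import DataLoader, DistributedSampler

sampler_train = DistributedSampler(train_dataset, shuffle=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler_train)

for epoch in range(num_epochs):
    # Without this call every epoch reuses the same shuffled order on each rank.
    sampler_train.set_epoch(epoch)
    for x, y in train_loader:
        ...  # forward/backward as usual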

If DP shards the batch in dim0, then each GPU should get batchsize / n_gpus, right? For my case that would mean batchsize / n_gpus = 30 / 2 = 15; I also tested batchsize / n_gpus = int(30 / 4) = 7 on each GPU. The GPU memory load reaches its limit and then throws a CUDA out of memory error.

Since it works without DP, I suspect the problem is related to either 2. or 3. How would I proceed to debug?

train_dataset = CoCoDataset(
    "data/coco/train.csv",
    "data/coco/images/",
    "data/coco/labels/",
    S=S,
    anchors=anchors,
    image_size=image_size,
    mode="train",
)

sampler_train = DistributedSampler(train_dataset, shuffle=True)

train_loader = DataLoader(
    dataset=train_dataset,
    num_workers=2,
    batch_size=32,
    drop_last=False,
    sampler=sampler_train,
)
  1. Snipped after:
    model = YoloV4_EfficentNet(nclasses=80).to(rank)

    ddp_model = DDP(model, device_ids=[rank], gradient_as_bucket_view = True)
    optimizer = optim.Adam(
        ddp_model.parameters(), lr=1e-4, weight_decay=0.0005, betas=(0.937, 0.999)
    )
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-4,
        epochs=300,
        steps_per_epoch=len(train_loader),
        anneal_strategy="cos",
    )

Just to make sure we are talking about the same configs and expectations.

The deprecated DataParallel (DP) module allocates a single batch on the default device and shards it in the forward pass in dim0. The output is collected again on the default device and will thus create a memory imbalance. I.e. if you are using 8 GPUs and a batch size of 32, each device will use 32/8=4 only (with the exception of the default device where the entire input and output are collected).

DDP on the other hand uses a single process per device and the defined batch size is also the used one on each device (or rank). If you want to use the same batch size as before with DP, you would need to decrease it and divide by the number of ranks. DDP is considered superior as it’s not causing a memory imbalance and reduces the communication between devices.

I’m currently unsure which config worked when, but you should double check what the actual batch size was.

Based on your code snippets it seems the batch size on each rank is set to 30, so you are increasing it in the DDP use case.
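
In code, the difference boils down to the following (illustrative arithmetic only, using the numbers from this thread):

n_gpus = 2
dp_loader_batch = 30    # batch size handed to the DataLoader when using DP
ddp_loader_batch = 16   # batch size handed to the DataLoader on each rank with DDP

per_gpu_dp = dp_loader_batch // n_gpus      # DP scatters dim0 -> 15 per GPU
effective_dp = dp_loader_batch              # 30 samples per optimizer step

per_gpu_ddp = ddp_loader_batch              # each rank uses the full value -> 16 per GPU
effective_ddp = ddp_loader_batch * n_gpus   # 32 samples per optimizer step

print(per_gpu_dp, effective_dp, per_gpu_ddp, effective_ddp)  # 15 30 16 32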

I abandoned the DP approach in favor of DDP after reading up on DP vs. DDP and what the standard way to do distributed training is. I was not entirely clear on the fact that DDP does not shard the batch at dim0, but instead runs its own process on each rank or GPU.

So to clarify, what you describe means that:

  1. DP shards the input batch at dim0, allocating a batch size of batchsize / n_gpus to each GPU.
  2. DDP, on the other hand, does not shard the batch but runs a single process on each GPU, meaning each GPU uses the full batchsize and the total batch size is batchsize * n_gpus.

So in my case, using DDP with n_gpus = 2 and batchsize = 30, each GPU gets a batch size of 30 and the total batch size is 30 * 2. This is too large; I can, however, fit a batch size of 16, meaning 16 on each GPU process and thus 32 = 16 * 2 in total.

Yes, your explanation is correct. In addition, DDP uses a DistributedSampler to make sure each rank only loads and processes its corresponding samples, avoiding duplication. This was not needed in the (deprecated) DP approach since only a single process was loading all samples.
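
A quick way to see the sharding (no process group needed when num_replicas and rank are passed explicitly):

import torch
from torch.utils.data import TensorDataset, DistributedSampler

dataset = TensorDataset(torch.arange(10))
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# rank 0 -> [0, 2, 4, 6, 8], rank 1 -> [1, 3, 5, 7, 9]: disjoint index shards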
