CUDNN_STATUS_MAPPING_ERROR or misaligned address with pretrained model on multiple GPUs (DataParallel) and half precision

I am getting CUDA runtime errors depending on the number of GPUs used. The problem seems to occur when I train a customised model, which uses a pretrained Resnet34 as one of its parts, on multiple GPUs. I couldn’t really figure out what the problem is.

I have 3 models:
Model A: customised model
Model B: pretrained Resnet34
Model C: Model A + some_linear_layers + Model B
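
Schematically, Model C wires these pieces together roughly as in the sketch below (hypothetical names and channel sizes, only to illustrate the structure; the real models are larger and the exact repro code is further down):

import torch.nn as nn
from torchvision import models

class ModelC(nn.Module):
    # Hypothetical sketch: Model A's feature map is mapped to a 3-channel
    # image by a few linear layers and then fed into Model B (ResNet34).
    def __init__(self, model_a, a_out_channels=5):
        super().__init__()
        self.model_a = model_a
        self.some_linear_layers = nn.Sequential(nn.Linear(a_out_channels, 3), nn.Tanh())
        self.model_b = models.resnet34(pretrained=True)

    def forward(self, x):
        feat = self.model_a(x)                        # (B, C, H, W)
        feat = feat.permute(0, 2, 3, 1)               # channels last for nn.Linear
        feat = self.some_linear_layers(feat)          # (B, H, W, 3)
        feat = feat.permute(0, 3, 1, 2).contiguous()  # back to (B, 3, H, W)
        return self.model_b(feat)

# wrapped for multi-GPU training, e.g.:
# model = nn.DataParallel(ModelC(model_a).cuda(), device_ids=[0, 1, 2, 3])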

I am on:
NVIDIA-SMI 450.51.06
Driver Version: 450.51.06
CUDA Version: 11.2
GPU: 4x Tesla P100
Torch: 1.8.1+cu111
Torchvision: 0.9.1+cu111
Python: 3.8.10
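
(For completeness, the PyTorch-side versions above can be queried like this; the driver and the reported CUDA version come from nvidia-smi:)

import sys
import torch
import torchvision

print("Python:", sys.version.split()[0])
print("Torch:", torch.__version__, "Torchvision:", torchvision.__version__)
print("CUDA (torch build):", torch.version.cuda)
print("GPUs:", torch.cuda.device_count(), torch.cuda.get_device_name(0))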

I am able to run Model A with batch size 16 on 4 GPUs, with the usage below:

|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000001:00:00.0 Off |                    0 |
| N/A   34C    P0    74W / 250W |  10133MiB / 16280MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000002:00:00.0 Off |                    0 |
| N/A   33C    P0   114W / 250W |  10093MiB / 16280MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla P100-PCIE...  Off  | 00000003:00:00.0 Off |                    0 |
| N/A   32C    P0   110W / 250W |  10093MiB / 16280MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla P100-PCIE...  Off  | 00000004:00:00.0 Off |                    0 |
| N/A   35C    P0   104W / 250W |  10093MiB / 16280MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

I am also able to fine-tune Model B with batch size 16 on 4 GPUs, with the usage below:

|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000001:00:00.0 Off |                    0 |
| N/A   34C    P0   133W / 250W |   3535MiB / 16280MiB |     83%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000002:00:00.0 Off |                    0 |
| N/A   31C    P0   131W / 250W |   3357MiB / 16280MiB |     50%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla P100-PCIE...  Off  | 00000003:00:00.0 Off |                    0 |
| N/A   30C    P0    79W / 250W |   1559MiB / 16280MiB |     50%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla P100-PCIE...  Off  | 00000004:00:00.0 Off |                    0 |
| N/A   34C    P0    67W / 250W |   1589MiB / 16280MiB |     57%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

However, when it comes to Model C (Model A + some_linear_layers + Model B) with batch size 16 on 4 GPUs, I get a RuntimeError:

Traceback (most recent call last):
  File "train_selfvit.py", line 139, in <module>
    main(config)
  File "train_selfvit.py", line 78, in main
    trainer.train()
  File "/attentionKPTs/base/base_trainer.py", line 71, in train
    result = self._train_epoch(epoch)
  File "/attentionKPTs/trainer/trainer.py", line 70, in _train_epoch
    street_cls, sat_cls, shift_out, angle_out = self.model(sat, street, self.device)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/usr/local/lib/python3.8/dist-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/attentionKPTs/model/model.py", line 327, in forward
    sat_feature = self.sat_processor(sat_im_feature).unsqueeze(1)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py", line 249, in forward
    return self._forward_impl(x)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py", line 240, in _forward_impl
    x = self.layer4(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py", line 74, in forward
    out = self.conv2(out)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py", line 395, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: CUDA error: misaligned address

But I don’t think it is an OOM problem, since Model C with batch size 4 or 6 on 1 GPU has sufficient memory:
For bs=4

|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000001:00:00.0 Off |                    0 |
| N/A   36C    P0   100W / 250W |  10555MiB / 16280MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000002:00:00.0 Off |                    0 |
| N/A   28C    P0    25W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

For bs=6

|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000001:00:00.0 Off |                    0 |
| N/A   35C    P0   138W / 250W |  15143MiB / 16280MiB |     93%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000002:00:00.0 Off |                    0 |
| N/A   28C    P0    25W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

When I try to train with batch size 4 on 2 GPUs, I get RuntimeError: CUDA error: misaligned address while calculating the loss.

However, when I try batch size 4 on 4 GPUs, it runs, albeit with low GPU usage:

|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000001:00:00.0 Off |                    0 |
| N/A   31C    P0    47W / 250W |   3861MiB / 16280MiB |     67%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000002:00:00.0 Off |                    0 |
| N/A   29C    P0    65W / 250W |   3593MiB / 16280MiB |     65%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla P100-PCIE...  Off  | 00000003:00:00.0 Off |                    0 |
| N/A   29C    P0    77W / 250W |   3621MiB / 16280MiB |     65%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla P100-PCIE...  Off  | 00000004:00:00.0 Off |                    0 |
| N/A   32C    P0    69W / 250W |   3593MiB / 16280MiB |     43%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

Any thoughts?

Could you post a minimal, executable code snippet to reproduce the issue, please?
Also, could you update to the latest stable release (1.9.0) or the nightly binary and rerun your code?

@ptrblck I went back to review my code and updated PyTorch to 1.9.0. I had made a mistake: I did not call @autocast() in the forward function. Fixing that, however, did not solve my problem.
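
For reference, the pattern I had been missing is the one from the AMP docs: autocast state is thread-local, and nn.DataParallel runs each replica's forward in its own thread, so forward needs its own autocast context, e.g. via the decorator (minimal sketch):

import torch.nn as nn
from torch.cuda.amp import autocast

class MyModule(nn.Module):  # hypothetical module, just to show the decorator
    @autocast()             # re-enters autocast in each DataParallel worker thread
    def forward(self, x):
        return x * 2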

The problem was actually caused by the order in which I define the parts of my model, which is weird, and I still don’t know why it is an issue.

Here is some minimal code to simulate the situation:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torchvision import models
from torch.nn import init
from einops.layers.torch import Rearrange
import argparse


class Resnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet34(pretrained=True, progress=True)

    @autocast()
    def forward(self, x):
        out = self.resnet(x).unsqueeze(1)
        return out

class MixResnet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.resnet = models.resnet34(pretrained=True, progress=True)
        self.processor = CustomizedHead()
        self.linkage = nn.Sequential(
                        Rearrange('b c h w -> b (h w) c'),
                        nn.Linear(5, 3),
                        nn.Tanh(),
                        Rearrange('b (h w) c -> b c h w', h=256, w=256)
                    )
        
        self.linkage.apply(weights_init_kaiming)

    @autocast()
    def forward(self, x):
        input_x = self.processor(x)
        input_x = self.linkage(input_x)
        out = self.resnet(input_x).unsqueeze(1)
        return out

class MixResnetReverse(nn.Module):
    def __init__(self):
        super().__init__()
        # Identical to MixResnet except that self.linkage is defined *before*
        # self.resnet and self.processor; the forward pass is the same.
        self.linkage = nn.Sequential(
                        Rearrange('b c h w -> b (h w) c'),
                        nn.Linear(5, 3),
                        nn.Tanh(),
                        Rearrange('b (h w) c -> b c h w', h=256, w=256)
                    )
        
        self.resnet = models.resnet34(pretrained=True, progress=True)
        self.processor = CustomizedHead()
        
        
        self.linkage.apply(weights_init_kaiming)

    @autocast()
    def forward(self, x):
        input_x = self.processor(x)
        input_x = self.linkage(input_x)
        out = self.resnet(input_x).unsqueeze(1)
        return out


class CustomizedHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_depth = 10
        self.processor = nn.Sequential(
                        Rearrange('b c h w -> b (h w) c'),
                        nn.Linear(self.token_depth, 7),
                        nn.Linear(7, 5),
                        nn.Tanh(),
                        Rearrange('b (h w) c -> b c h w', h=256, w=256)
                    )
        self.processor.apply(weights_init_kaiming)

    def forward(self, x):
        batch_size = x.shape[0]
        # pad the input with random "token" channels so it has token_depth channels in total
        token = torch.rand(batch_size, self.token_depth - x.shape[1], x.shape[2], x.shape[3]).to(x.device)
        input_x = torch.cat((token, x), 1)
        return self.processor(input_x)


class Customized(nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = CustomizedHead()
        self.tail = nn.Sequential(
                        nn.AvgPool2d((16, 16)),
                        Rearrange('b c h w -> b 1 (h w c)', h=16, w=16),
                        nn.Linear(1280, 1000),
                        nn.Sigmoid()
                    )
        self.tail.apply(weights_init_kaiming)            

    @autocast()
    def forward(self, x):
        mid_x = self.processor(x)
        out = self.tail(mid_x)
        return out


def main(args):
    if args.arch=='Resnet':
        model = Resnet()
    elif args.arch=='MixResnet':
        model = MixResnet()
    elif args.arch=='MixResnetReverse':
        model = MixResnetReverse()
    else:
        model = Customized()
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(trainable_params)
    loss_ftn = nn.MSELoss()
    scaler = GradScaler()

    device = torch.device('cuda:0' if args.n_gpu > 0 else 'cpu')
    list_ids = list(range(args.n_gpu))
    model = model.to(device)
    if len(list_ids)>1:
        model = torch.nn.DataParallel(model, device_ids=list_ids)
    
    model.train()
    for step in range(10):
        x = torch.rand(args.batch_size, 3, 256, 256).to(device)
        y = torch.rand(args.batch_size, 3, 256, 256).to(device)
        gt = torch.rand(args.batch_size, 1, 1000).to(device)
        with autocast():
            optimizer.zero_grad()
            if args.arch == 'Mixed_Resnet34':
                # leftover branch from the original Model C; never taken with the archs defined above
                output, _, _, _ = model(x, y)
            else:
                output = model(x)
            loss = loss_ftn(output, gt)
        print(f"Iter: {iter}, loss: {loss.item()}")
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    
    print('finish!')

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv2d') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal.
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
    elif classname.find('InstanceNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
    elif classname.find('LayerNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--n_gpu', default=4, type=int)
    parser.add_argument('--arch', default='Resnet', type=str)
    parser.add_argument('--batch_size', default=16, type=int)
    args = parser.parse_args()
    main(args)

I created 4 models:
Resnet → Resnet34
Customized → Some customised model
MixResnet → Customized + linkage + Resnet
MixResnetReverse → Customized + linkage + Resnet (but with a slightly different order in the __init__ function, where the linkage is defined before the Resnet)

You can test the four different models with different batch sizes and numbers of GPUs.

python mini_model.py --n_gpu 4 --batch_size 16 --arch MixResnet
python mini_model.py --n_gpu 4 --batch_size 16 --arch Customized
python mini_model.py --n_gpu 4 --batch_size 16 --arch Resnet

These three work well.

python mini_model.py --n_gpu 4 --batch_size 16 --arch MixResnetReverse

gives a mapping error:

Traceback (most recent call last):
  File "mini_model.py", line 170, in <module>
    main(args)
  File "mini_model.py", line 138, in main
    output = model(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/usr/local/lib/python3.8/dist-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/amp/autocast_mode.py", line 141, in decorate_autocast
    return func(*args, **kwargs)
  File "mini_model.py", line 63, in forward
    out = self.resnet(input_x).unsqueeze(1)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py", line 249, in forward
    return self._forward_impl(x)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py", line 238, in _forward_impl
    x = self.layer2(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py", line 71, in forward
    out = self.bn1(out)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/batchnorm.py", line 167, in forward
    return F.batch_norm(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 2281, in batch_norm
    return torch.batch_norm(
RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR

While

python mini_model.py --n_gpu 4 --batch_size 4 --arch MixResnetReverse
python mini_model.py --n_gpu 1 --batch_size 4 --arch MixResnetReverse

work fine.

I wonder why simply changing the order of two lines makes such a big difference.

Thanks!

I’m not familiar with the einops repository, so I don’t know how the different orders could create the issue. To further isolate it, you could rerun the code with CUDA_LAUNCH_BLOCKING=1 and see if the error message points to another operation, as cuDNN might be running into a previous (sticky) error.
If no error checking is done in the 3rd-party libraries, you might also need to synchronize the code manually and e.g. print a CUDA tensor to check if an internal assert was triggered.
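
A minimal sketch of these debugging steps, assuming the repro above is saved as mini_model.py (the helper below is hypothetical, not part of the original code):

# 1) Make kernel launches synchronous so the reported stack trace points at the
#    operation that actually failed:
#      CUDA_LAUNCH_BLOCKING=1 python mini_model.py --n_gpu 4 --batch_size 16 --arch MixResnetReverse
#
# 2) Or synchronize manually around suspicious ops; a sticky asynchronous error
#    raised by an earlier kernel will surface at the synchronization point:
import torch

def cuda_check(tensor, name=""):
    torch.cuda.synchronize()
    print(name, tensor.float().abs().sum().item())  # touching the tensor forces evaluation

# e.g. inside MixResnetReverse.forward:
#     input_x = self.linkage(input_x)
#     cuda_check(input_x, "after linkage")
#     out = self.resnet(input_x).unsqueeze(1)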