CUDNN_STATUS_NOT_SUPPORTED when increasing batch size (PS: this is not an OOM error)

I am facing an issue with a very basic PyTorch block on a machine with an A100 GPU with plenty of memory (80 GB).

pytorch                   1.12.0          py3.8_cuda11.3_cudnn8.3.2_0  
torchvision               0.13.0               py38_cu113 
cudatoolkit               11.3.1

I suspect it could be an issue with cuDNN and batchnorm.
I’m using the following small script to reproduce the problem:

```python
import argparse
import sys

import torch
import torch.nn as nn

print(f'cuda is available : {torch.cuda.is_available()}')


def main(parsed):
    if parsed.cudnn_off:
        # Disable cuDNN globally so every layer falls back to PyTorch's native CUDA kernels
        torch.backends.cudnn.enabled = False
        print('cuDNN is OFF !')
    else:
        print('cuDNN is ON !')

    x = torch.rand(parsed.batch_size, 1, 300, 300).cuda()

    conv = nn.Conv2d(1, 128, 7, bias=False).cuda()
    bn = nn.BatchNorm2d(128).cuda()

    x = conv(x)  # (N, 128, 294, 294): 300 - 7 + 1 = 294 with no padding
    print(x.shape)

    x = bn(x)  # raises CUDNN_STATUS_NOT_SUPPORTED with -bs 256 when cuDNN is on
    print(x.shape)


if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument('-bs', '--batch_size', default=100, type=int, help='batch size')
    args.add_argument('-disable_cudnn', '--cudnn_off', action='store_true')

    parsed = args.parse_args()

    # Print usage and exit when the script is run without any arguments
    if len(sys.argv) == 1:
        args.print_help()
        sys.exit()

    main(parsed)
```

As you can tell, this is not really rocket science: just a convolution followed by a batchnorm layer.

If cuDNN is activated ==> NOT WORKING

```
(ecgtraining) [HydraPulseML]$ python test_batchnorm.py -bs 256
cuda is available : True
cuDNN is ON !
torch.Size([256, 128, 294, 294])
Traceback (most recent call last):
  File "test_batchnorm.py", line 38, in <module>
    main(parsed)
  File "test_batchnorm.py", line 23, in main
    x = bn(x)
  File "/nics/b/home/doursand/anaconda3/envs/ecgtraining/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/nics/b/home/doursand/anaconda3/envs/ecgtraining/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/nics/b/home/doursand/anaconda3/envs/ecgtraining/lib/python3.8/site-packages/torch/nn/functional.py", line 2438, in batch_norm
    return torch.batch_norm(
RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.
```
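The error message hints at a non-contiguous input. A minimal diagnostic sketch for ruling that out: the `describe` helper below is hypothetical, and it runs a batch of one on the CPU so no GPU is needed. Since the memory layout on the GPU could in principle differ, the meaningful check is to call `describe(x)` on the real CUDA tensor right before `bn(x)`.

```python
import torch
import torch.nn as nn


def describe(t: torch.Tensor) -> dict:
    """Hypothetical helper: report the size and memory layout of a tensor."""
    return {
        "shape": tuple(t.shape),
        "numel": t.numel(),
        # The property the cuDNN error message points at
        "contiguous": t.is_contiguous(),
        "channels_last": t.is_contiguous(memory_format=torch.channels_last),
    }


if __name__ == "__main__":
    # CPU stand-in for the repro: same layer and input size, but batch size 1
    conv = nn.Conv2d(1, 128, 7, bias=False)
    with torch.no_grad():
        y = conv(torch.rand(1, 1, 300, 300))
    print(describe(y))
```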

If cuDNN is deactivated ==> WORKING

```
(ecgtraining) [HydraPulseML]$ python test_batchnorm.py -bs 256 --cudnn_off
cuda is available : True
cuDNN is OFF !
torch.Size([256, 128, 294, 294])
torch.Size([256, 128, 294, 294])
```
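Because the failure only shows up as the batch size grows, it can help to count the elements that reach the batchnorm layer. This is a minimal sketch: the helper names are hypothetical, and the `2**31 - 1` threshold is an assumption (a signed 32-bit element count), not a limit confirmed by cuDNN or stated in this thread.

```python
# ASSUMPTION: a signed 32-bit element limit; the real cuDNN constraint is not documented here
INT32_MAX = 2**31 - 1


def conv2d_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output height/width of nn.Conv2d, using the formula from the PyTorch docs."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def bn_input_numel(batch_size: int, channels: int = 128, size: int = 300, kernel: int = 7) -> int:
    """Elements in the conv output that BatchNorm2d normalizes (defaults match the repro)."""
    out = conv2d_out_size(size, kernel)
    return batch_size * channels * out * out


def max_batch_under(limit: int = INT32_MAX, **kwargs) -> int:
    """Largest batch size whose batchnorm input stays <= limit (under the assumed limit)."""
    per_sample = bn_input_numel(1, **kwargs)
    return limit // per_sample


if __name__ == "__main__":
    for bs in (100, 256):
        n = bn_input_numel(bs)
        print(f"bs={bs}: {n:,} elements, exceeds INT32_MAX: {n > INT32_MAX}")
    print("largest batch size under the assumed limit:", max_batch_under())
```

At batch size 256 the conv output has 2,832,334,848 elements, above `2**31 - 1`, while the script's default batch size of 100 gives 1,106,380,800. Whether 32-bit indexing is actually what cuDNN trips over is not confirmed here; the arithmetic only shows where that particular boundary would fall (batch size 194 for this layer).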

Any help would be appreciated.

I’m facing exactly the same issue! It seems something is wrong with cuDNN and BatchNorm.

Thanks for reporting the issue. It seems to be failing in the batchnorm layer (even in the latest cuDNN version), so I’ll forward the issue to our cuDNN team and keep you updated about the fix.


Hi @ptrblck, I wonder whether you have had a chance to get in touch with the cuDNN team yet?

Thanks for the ping. Yes, it seems to be a known cuDNN limitation based on the number of elements, so we would have to disable cuDNN batchnorm for these workloads.
As a workaround you could either disable cuDNN globally via torch.backends.cudnn.enabled = False (which might also slow down other layers) or just for this batchnorm layer via:

```python
with torch.backends.cudnn.flags(enabled=False):
    out = bn(x)
```
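To apply the per-layer workaround without wrapping every call site, the same context manager can be moved into the module itself. A minimal sketch, assuming subclassing `nn.BatchNorm2d` is acceptable in your model; the class name `BatchNorm2dNoCudnn` is hypothetical.

```python
import torch
import torch.nn as nn


class BatchNorm2dNoCudnn(nn.BatchNorm2d):
    """Hypothetical BatchNorm2d variant whose forward pass runs with cuDNN disabled.

    Other layers (e.g. the preceding Conv2d) keep using cuDNN.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.backends.cudnn.flags is a context manager; besides `enabled`, it also sets
        # the other cuDNN flags (benchmark, deterministic, ...) to its own defaults inside
        # the block, and restores the previous values on exit
        with torch.backends.cudnn.flags(enabled=False):
            return super().forward(x)
```

In the repro script this would mean replacing `nn.BatchNorm2d(128).cuda()` with `BatchNorm2dNoCudnn(128).cuda()`. Running statistics, affine parameters, and `train()`/`eval()` behavior are inherited unchanged; only the backend used for this one layer differs.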

Hi all! Just curious, are there any new developments or updates on this? If anyone has the latest info or a rough timeline for a resolution, that’d be really helpful. Thanks for staying on top of this!

Did the workaround fail for you?