UNET model: cuDNN runtime error on GPU

I am building a custom U-Net model that takes a NumPy matrix as input, with an associated binary (0/1) mask value for each pixel. I convert each matrix to a torch image of the form C x H x W by adding a leading dimension with numpy.expand_dims(npMat, 0). Here is the model for your reference:

class UNET(nn.Module):
    """Small three-level U-Net-style encoder/decoder.

    Each encoder stage (``contract_block``) applies conv-BN-ReLU twice and then
    halves the spatial size with a max-pool; each decoder stage
    (``expand_block``) applies conv-BN-ReLU twice and then doubles the spatial
    size with a transposed convolution. Decoder stages 2 and 1 receive the
    previous decoder output concatenated (along channels) with the encoder
    output of matching resolution.

    Input height and width must be multiples of 8 so the skip-connection
    tensors line up for ``torch.cat`` (e.g. 224 x 224 works).

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map (raw scores; no
            final activation is applied).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder: each stage halves H and W.
        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)

        # Decoder: each stage doubles H and W. The skip connections double the
        # input channels of upconv2 and upconv1.
        self.upconv3 = self.expand_block(128, 64, 3, 1)
        self.upconv2 = self.expand_block(64 * 2, 32, 3, 1)
        self.upconv1 = self.expand_block(32 * 2, out_channels, 3, 1)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(N, in_channels, H, W)``.

        Returns a tensor of shape ``(N, out_channels, H, W)``.
        """
        # Implement ``forward`` rather than overriding ``__call__``:
        # nn.Module.__call__ dispatches here AND runs registered hooks, which an
        # overridden __call__ silently skips.

        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        # upsampling part, with skip connections concatenated on the channel axis
        upconv3 = self.upconv3(conv3)
        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Build an encoder stage: (Conv-BN-ReLU) x2 then a stride-2 max-pool.

        The max-pool (kernel 3, stride 2, padding 1) maps a spatial size S to
        ceil(S / 2).
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Build a decoder stage: (Conv-BN-ReLU) x2 then a 2x transposed conv.

        The transposed conv (kernel 3, stride 2, padding 1, output_padding 1)
        maps a spatial size S to exactly 2 * S.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
        )

Here is the training part:

import time
from IPython.display import clear_output

def train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs=1):
    """Train ``model`` and evaluate it on a validation set after every epoch.

    Args:
        model: the ``nn.Module`` to train; it is moved to the selected device
            (CUDA when available, otherwise CPU).
        train_dl: DataLoader yielding ``(x, y)`` training batches.
        valid_dl: DataLoader yielding ``(x, y)`` validation batches.
        loss_fn: called as ``loss_fn(torch.sigmoid(outputs), y)`` in both the
            training and the validation phase.
        optimizer: optimizer over ``model``'s parameters.
        acc_fn: called as ``acc_fn(outputs, y)`` on the raw (detached) model
            outputs; must return a number or a one-element tensor.
        epochs: number of passes over ``train_dl``.

    Returns:
        ``(train_loss, valid_loss)``: lists holding one mean loss per epoch,
        as Python floats.
    """
    print("Inside train..")
    start = time.time()
    # "gpu" is not a valid torch device type; use CUDA when present, else CPU.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(dev)

    train_loss, valid_loss = [], []

    print("Starting Epochs........")
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            # train(True) / train(False) switches BatchNorm between batch and running stats.
            model.train(is_train)
            dataloader = train_dl if is_train else valid_dl

            running_loss = 0.0
            running_acc = 0.0
            n_seen = 0

            # iterate over data
            for step, (x, y) in enumerate(dataloader, start=1):
                x = x.to(dev)
                y = y.to(dev)

                # Build the autograd graph only in the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = model(x)
                    # Same loss input in both phases so train/valid losses are comparable.
                    # NOTE(review): this assumes loss_fn expects probabilities
                    # (e.g. nn.BCELoss); with BCEWithLogitsLoss drop the sigmoid.
                    loss = loss_fn(torch.sigmoid(outputs), y)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                acc = acc_fn(outputs.detach(), y)

                # Accumulate Python numbers (not tensors) so each step's graph is
                # freed, and weight by the actual batch size: the last batch may
                # be smaller than dataloader.batch_size.
                batch_size = x.size(0)
                running_acc += float(acc) * batch_size
                running_loss += loss.item() * batch_size
                n_seen += batch_size

                if step % 100 == 0:
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(
                        step, loss.item(), float(acc), torch.cuda.memory_allocated() / 1024 / 1024))

            # Average over the samples actually seen; max() avoids 0/0 on an empty loader.
            epoch_loss = running_loss / max(n_seen, 1)
            epoch_acc = running_acc / max(n_seen, 1)

            clear_output(wait=True)
            print('Epoch {}/{}'.format(epoch, epochs - 1))
            print('-' * 10)
            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))
            print('-' * 10)

            if is_train:
                train_loss.append(epoch_loss)
            else:
                valid_loss.append(epoch_loss)

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return train_loss, valid_loss

def acc_metric(predb, yb):
    """Return the fraction of positions whose predicted label matches ``yb``.

    Args:
        predb: model output of shape ``(N, C, ...)``. With ``C > 1`` the
            predicted label is the channel argmax. With ``C == 1`` the single
            channel is treated as a logit and thresholded at 0 (sigmoid 0.5),
            because an argmax over one channel would always predict 0.
        yb: target labels of shape ``(N, ...)`` or ``(N, 1, ...)``; moved to
            ``predb``'s device.

    Returns:
        A 0-dim float tensor holding the mean accuracy.
    """
    # Follow predb's device instead of hard-coding .cuda(), so CPU runs work too.
    yb = yb.to(predb.device)
    if predb.size(1) == 1:
        preds = (predb > 0).squeeze(1)
    else:
        preds = predb.argmax(dim=1)
    # Drop a singleton channel axis on the target so (N, ...) is not broadcast
    # against (N, 1, ...), which would compare every sample with every other.
    if yb.dim() == preds.dim() + 1 and yb.size(1) == 1:
        yb = yb.squeeze(1)
    return (preds == yb).float().mean()

Here is the complete error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-40-effb400478a6> in <module>()
      2 #output = loss_fn(1, target)
      3 opt = torch.optim.Adam(unet.parameters(), lr=0.01)
----> 4 train_loss, valid_loss = train(unet, train_dl, valid_dl, loss_fn, opt, acc_metric, epochs=50)

<ipython-input-37-5335f67c61f3> in train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs)
     43                 if phase == 'train':
     44                     optimizer.zero_grad()
---> 45                     outputs = model(x)
     46                     loss = loss_fn(torch.sigmoid(outputs), y)
     47 

<ipython-input-36-6dd22bc4cdce> in __call__(self, x)
     14 
     15         # downsampling part
---> 16         conv1 = self.conv1(x)
     17         conv2 = self.conv2(conv1)
     18         conv3 = self.conv3(conv2)

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

This gives me "RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED". If I switch to the CPU, it takes forever to run and then the kernel times out. I'm not sure how to solve this. Any help is appreciated, thanks!

Could you post the output of python -m torch.utils.collect_env?

Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux Server release 7.9 (Maipo) (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.17

Python version: 3.8.1 | packaged by conda-forge | (default, Jan 29 2020, 14:55:04)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.31.1.el7.x86_64-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 9.0.176
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe

Nvidia driver version: 525.60.13
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                48
On-line CPU(s) list:   0-47
Thread(s) per core:    1
Core(s) per socket:    24
Socket(s):             2
NUMA node(s):          4
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 106
Model name:            Intel(R) Xeon(R) Gold 6342 CPU @ 2.80GHz
Stepping:              6
CPU MHz:               2800.000
BogoMIPS:              5600.00
Virtualization:        VT-x
L1d cache:             48K
L1i cache:             32K
L2 cache:              1280K
L3 cache:              36864K
NUMA node0 CPU(s):     0-11
NUMA node1 CPU(s):     12-23
NUMA node2 CPU(s):     24-35
NUMA node3 CPU(s):     36-47
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 invpcid_single intel_pt ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq md_clear pconfig spec_ctrl intel_stibp flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0
[pip3] torch-summary==1.4.5
[pip3] torch-utils==0.1.2
[pip3] torchvision==0.15.1
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] torch                     2.0.0                    pypi_0    pypi
[conda] torch-summary             1.4.5                    pypi_0    pypi
[conda] torch-utils               0.1.2                    pypi_0    pypi
[conda] torchvision               0.15.1                   pypi_0    pypi

Thanks for the information.
I’m unable to reproduce the issue using:

import torch
import torch.nn as nn


class UNET(nn.Module):
    """Small three-level U-Net-style encoder/decoder.

    Each encoder stage (``contract_block``) applies conv-BN-ReLU twice and then
    halves the spatial size with a max-pool; each decoder stage
    (``expand_block``) applies conv-BN-ReLU twice and then doubles the spatial
    size with a transposed convolution. Decoder stages 2 and 1 receive the
    previous decoder output concatenated (along channels) with the encoder
    output of matching resolution.

    Input height and width must be multiples of 8 so the skip-connection
    tensors line up for ``torch.cat`` (e.g. 224 x 224 works).

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map (raw scores; no
            final activation is applied).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder: each stage halves H and W.
        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)

        # Decoder: each stage doubles H and W. The skip connections double the
        # input channels of upconv2 and upconv1.
        self.upconv3 = self.expand_block(128, 64, 3, 1)
        self.upconv2 = self.expand_block(64 * 2, 32, 3, 1)
        self.upconv1 = self.expand_block(32 * 2, out_channels, 3, 1)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(N, in_channels, H, W)``.

        Returns a tensor of shape ``(N, out_channels, H, W)``.
        """
        # Implement ``forward`` rather than overriding ``__call__``:
        # nn.Module.__call__ dispatches here AND runs registered hooks, which an
        # overridden __call__ silently skips.

        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        # upsampling part, with skip connections concatenated on the channel axis
        upconv3 = self.upconv3(conv3)
        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Build an encoder stage: (Conv-BN-ReLU) x2 then a stride-2 max-pool.

        The max-pool (kernel 3, stride 2, padding 1) maps a spatial size S to
        ceil(S / 2).
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Build a decoder stage: (Conv-BN-ReLU) x2 then a 2x transposed conv.

        The transposed conv (kernel 3, stride 2, padding 1, output_padding 1)
        maps a spatial size S to exactly 2 * S.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
        )


# GPU smoke run: push one random single-channel 224x224 image through the
# network and report the output shape.
model = UNET(1, 1).to("cuda")
x = torch.randn(1, 1, 224, 224).to("cuda")
out = model(x)
print(out.shape)

Output:

torch.Size([1, 1, 224, 224])

Your CUDA runtime version looks really old:

CUDA runtime version: 9.0.176

and I don’t know where this is coming from. The PyTorch 2.0.0+cu117 pip wheel ships with the CUDA 11.7 runtime and your driver also is quite new. Did you install any old PyTorch packages into your env?
By the way, did any other workload run successfully on this system before?

I get the same error on Google Colab as well, so I don't think the system is the problem here. Yes, many workloads have run successfully on this system before. Could it be caused by the underlying data?

I think it could be related to this? retinanet - Tensorflow 2.1 Failed to get convolution algorithm. This is probably because cuDNN failed to initialize - Stack Overflow

I don’t think your error is related to TensorFlow.
Did you run my code and are you seeing the same error?

I got the same error running your code:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-25-211cce0ed205> in <module>()
     58 x = torch.randn(1, 1, 224, 224).cuda()
     59 
---> 60 out = model(x)
     61 print(out.shape)

<ipython-input-25-211cce0ed205> in __call__(self, x)
     18 
     19         # downsampling part
---> 20         conv1 = self.conv1(x)
     21         conv2 = self.conv2(conv1)
     22         conv3 = self.conv3(conv2)

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

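The issue was resolved by switching the CUDA runtime from version 9.0.176 to the CUDA 11.7 runtime. (The tracebacks point to a separate installation under /gpfs/share/apps/anaconda3/gpu/5.2.0/lib/python3.6/, while collect_env reported Python 3.8 with PyTorch 2.0.0+cu117. The notebook kernel was therefore most likely running an older PyTorch build linked against CUDA 9.0. That toolkit predates the A100 GPUs in this machine, which require CUDA 11.0 or newer.)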