CUDA runtime error (719) Unspecified Launch Error during training

I’m trying to recreate something like the CRN in table 1 of this paper (except double the input channels) and I get a CUDA runtime error at seemingly random times during training.

My model and trainer example code:
CRN_torch_model_simple.py

import torch
import torch.nn as nn
import torch.nn.functional as F


# Shape of one input batch; the training script describes it as
# (T_frames, N_channels, 1, N_bins).
input_shape = (1001, 4, 1, 161)

# Every (transposed) convolution uses a 1x3 kernel and a stride of 2 along
# the last axis only.
kernel_dim = (1, 3)
stride_size = (1, 2)

# LSTM feature size. It has to equal channels * height * width of the last
# encoder output so that the encoder output can be viewed as
# (-1, 1, LSTM_size).
LSTM_size = 1024

# Encoder channel widths: 16, 32, 64, 128, 256 (doubling at every stage).
n_filters = [16 << stage for stage in range(5)]

# Activation names understood by decoder_layer().
primary_act = 'elu'
output_act = 'elu'


# Bundle ConvT, BatchNorm, and Activation together
def decoder_layer(in_c, out_c, k_dims, s_dims, out_pad=(0,0), act_type='elu'):
    """Build one decoder stage: ConvTranspose2d -> BatchNorm2d -> activation.

    Args:
        in_c: Number of input channels of the transposed convolution.
        out_c: Number of output channels (also the BatchNorm2d feature count).
        k_dims: Kernel size passed to ``nn.ConvTranspose2d``.
        s_dims: Stride passed to ``nn.ConvTranspose2d``.
        out_pad: ``output_padding`` passed to ``nn.ConvTranspose2d``.
        act_type: Activation name, case-insensitive: ``'elu'`` or ``'softmax'``.

    Returns:
        An ``nn.Sequential`` holding the three layers in the order above.

    Raises:
        ValueError: If ``act_type`` is not one of the supported names.
    """
    act_name = act_type.lower()
    if act_name == 'elu':
        activation = nn.ELU()
    elif act_name == 'softmax':
        # The stage always sees 4-D input (BatchNorm2d requires it), for which
        # the deprecated implicit choice of nn.Softmax() is dim=1. Spelling it
        # out keeps the result and avoids the "implicit dimension" warning.
        activation = nn.Softmax(dim=1)
    else:
        # Fail with a clear message instead of hitting an unbound local below.
        raise ValueError(
            f"Unsupported act_type {act_type!r}; expected 'elu' or 'softmax'"
        )

    return nn.Sequential(
        nn.ConvTranspose2d(in_c, out_c, k_dims, s_dims, output_padding=out_pad),
        nn.BatchNorm2d(out_c),
        activation,
    )


class CRN_model(nn.Module):
    """Convolutional recurrent network: shared encoder, LSTM, two decoders.

    The encoder is five Conv2d -> BatchNorm2d -> ELU stages. Its output is
    flattened to vectors of length ``LSTM_size`` and run through a two-layer
    LSTM as one sequence with batch size 1. Two structurally identical
    decoder branches ("RE" and "IM") each upsample the LSTM output through
    five ``decoder_layer`` stages; before every stage the running tensor is
    concatenated, along the channel axis, with the pre-BatchNorm conv output
    of the matching encoder stage. Each branch ends with one channel and the
    two results are concatenated on dim 1.
    """

    def __init__(self):
        super().__init__()

        # Encoder conv layers
        self.conv2d_0 = nn.Conv2d(input_shape[1], n_filters[0], kernel_dim, stride_size)
        self.conv2d_1 = nn.Conv2d(n_filters[0], n_filters[1], kernel_dim, stride_size)
        self.conv2d_2 = nn.Conv2d(n_filters[1], n_filters[2], kernel_dim, stride_size)
        self.conv2d_3 = nn.Conv2d(n_filters[2], n_filters[3], kernel_dim, stride_size)
        self.conv2d_4 = nn.Conv2d(n_filters[3], n_filters[4], kernel_dim, stride_size)

        # Encoder Batch Norm layers
        self.BN_enc_0 = nn.BatchNorm2d(n_filters[0])
        self.BN_enc_1 = nn.BatchNorm2d(n_filters[1])
        self.BN_enc_2 = nn.BatchNorm2d(n_filters[2])
        self.BN_enc_3 = nn.BatchNorm2d(n_filters[3])
        self.BN_enc_4 = nn.BatchNorm2d(n_filters[4])

        # Two stacked LSTM layers; input and hidden sizes are both LSTM_size.
        self.LSTM_0 = nn.LSTM(LSTM_size, LSTM_size, 2)

        # Decoder stages. Input channel counts are doubled because every
        # stage receives its predecessor's output concatenated with an
        # encoder skip tensor of the same channel width.
        self.dec_RE_0 = decoder_layer(2*n_filters[4], n_filters[3], kernel_dim,
                                      stride_size, act_type=primary_act)
        self.dec_IM_0 = decoder_layer(2*n_filters[4], n_filters[3], kernel_dim,
                                      stride_size, act_type=primary_act)

        self.dec_RE_1 = decoder_layer(2*n_filters[3], n_filters[2], kernel_dim,
                                      stride_size, act_type=primary_act)
        self.dec_IM_1 = decoder_layer(2*n_filters[3], n_filters[2], kernel_dim,
                                      stride_size, act_type=primary_act)

        self.dec_RE_2 = decoder_layer(2*n_filters[2], n_filters[1], kernel_dim,
                                      stride_size, act_type=primary_act)
        self.dec_IM_2 = decoder_layer(2*n_filters[2], n_filters[1], kernel_dim,
                                      stride_size, act_type=primary_act)

        # Stage 3 adds one column of output padding on the last axis.
        self.dec_RE_3 = decoder_layer(2*n_filters[1], n_filters[0], kernel_dim,
                                      stride_size, out_pad=(0,1), act_type=primary_act)
        self.dec_IM_3 = decoder_layer(2*n_filters[1], n_filters[0], kernel_dim,
                                      stride_size, out_pad=(0,1), act_type=primary_act)

        self.dec_RE_4 = decoder_layer(2*n_filters[0], 1, kernel_dim,
                                      stride_size, act_type=output_act)
        self.dec_IM_4 = decoder_layer(2*n_filters[0], 1, kernel_dim,
                                      stride_size, act_type=output_act)

    @staticmethod
    def _decode(stages, bottleneck, skips):
        """Run one decoder branch, pairing stages with skips deepest-first."""
        dec_out = bottleneck
        for stage, skip in zip(stages, reversed(skips)):
            dec_out = stage(torch.cat([dec_out, skip], dim=1))
        return dec_out

    def forward(self, x):
        """Run the network on ``x`` and return the two-channel decoder output.

        Args:
            x: 4-D input tensor with ``input_shape[1]`` channels on dim 1.
                The encoder output for ``x`` must hold a multiple of
                ``LSTM_size`` elements, otherwise the reshape before the
                LSTM fails.

        Returns:
            Tensor made of the RE branch output followed by the IM branch
            output along dim 1.
        """
        # Fresh random initial LSTM state on every call, shaped
        # (num_layers=2, batch=1, LSTM_size). It is created on the input's
        # device and dtype so the module works wherever the data lives
        # instead of assuming the default CUDA device.
        hn = torch.randn((2, 1, LSTM_size), device=x.device, dtype=x.dtype,
                         requires_grad=True)
        cn = torch.randn((2, 1, LSTM_size), device=x.device, dtype=x.dtype,
                         requires_grad=True)

        ########## ENCODER LAYERS ##########
        encoder = (
            (self.conv2d_0, self.BN_enc_0),
            (self.conv2d_1, self.BN_enc_1),
            (self.conv2d_2, self.BN_enc_2),
            (self.conv2d_3, self.BN_enc_3),
            (self.conv2d_4, self.BN_enc_4),
        )
        skips = []  # pre-BatchNorm conv outputs, shallowest first
        enc_act = x
        for conv, batch_norm in encoder:
            conv_out = conv(enc_act)
            skips.append(conv_out)
            enc_act = F.elu(batch_norm(conv_out))

        ########## RECURRENT LAYERS ##########
        enc_out_shape = enc_act.shape

        # nn.LSTM defaults to (seq_len, batch, feature): dim 0 of the encoder
        # output becomes the sequence axis and the batch size is 1.
        LSTM_in = enc_act.view(-1, 1, LSTM_size)
        LSTM_out, _ = self.LSTM_0(LSTM_in, (hn, cn))
        LSTM_out = LSTM_out.view(enc_out_shape)

        ########## DECODER LAYERS ##########
        # torch.cat allocates a new tensor and leaves its inputs untouched,
        # so both branches can read LSTM_out directly without copying it.
        dec_out_RE = self._decode(
            (self.dec_RE_0, self.dec_RE_1, self.dec_RE_2, self.dec_RE_3, self.dec_RE_4),
            LSTM_out, skips)
        dec_out_IM = self._decode(
            (self.dec_IM_0, self.dec_IM_1, self.dec_IM_2, self.dec_IM_3, self.dec_IM_4),
            LSTM_out, skips)

        ########## OUTPUT ##########
        return torch.cat((dec_out_RE, dec_out_IM), dim=1)


simplified_training_ex.py

import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch

# My model
from CRN_torch_model_simple import CRN_model 

# Simple Dataset class to load batches from numpy matrix
from CRN_dataset_stream import CRN_Dataset 

from torch.utils.data import DataLoader


def main():
    """Push every batch of the dataset through a fresh CRN_model on the GPU.

    No loss is computed and no optimizer step is taken; the loop only runs
    forward passes and reports allocated GPU memory every tenth batch.
    """
    device = torch.device("cuda")

    # The dataset is built from the directory that contains this script.
    top_dir = os.path.dirname(os.path.realpath(__file__))
    training_loader = DataLoader(CRN_Dataset(top_dir))

    model = CRN_model().to(device)

    print("START TRAINING")
    for k, data in enumerate(training_loader):
        # Each batch is (T_frames, N_channels, 1, N_bins) = (1001, 4, 1, 161)
        # once the extra leading dimension is squeezed away.
        x_in = data[0].squeeze(dim=0).type(torch.float).to(device)

        model.zero_grad()

        _ = model(x_in)

        if k % 10 != 0:
            continue
        print(f"{k} batches done")
        print(f"GPU memory allocated : {torch.cuda.memory_allocated()/1024**3}")


if __name__ == "__main__":
    main()

In the example code I’m not even training the model, as it is not necessary to demonstrate the error. All I’m doing is forward propagating a bunch of shape (1001,4,1,161) cuda float tensors through the model.

This is the error I get every time.

THCudaCheck FAIL file=..\aten\src\THC\THCCachingHostAllocator.cpp line=278 error=719 : unspecified launch failure
Traceback (most recent call last):
  File ".\simplified_training_ex.py", line 46, in <module>
    main()
  File ".\simplified_training_ex.py", line 38, in main
    _ = model(x_in)
  File "C:\Users\silva\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "I:\CRN_model\CRN_torch_model_simple.py", line 144, in forward
    dec_in = torch.cat([dec_out, conv_0], dim=1)
RuntimeError: cuda runtime error (719) : unspecified launch failure at ..\aten\src\THC\THCCachingHostAllocator.cpp:278

When I uncomment the line:

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

I don’t get that error anymore but it’s much slower to iterate through batches (which is expected). I’d like to know why I keep getting that CUDA runtime error in non-blocking mode. Is there any way I could get around it?

Environment:

  • Windows 10 Home
  • Ryzen 7 3700
  • 128GB RAM
  • 2080TI (456.55)
  • Dist: Conda 4.8.5
  • Python 3.8.3
  • torch version 1.6.0
  • torch CUDA version 10.2

Could you check whether you are running out of memory, and try disabling TDR as suggested here?

I’m not running out of memory. I’ve monitored the GPU mem allocated and it’s no more than 1.7GB. I’ve also disabled TDR and it didn’t change anything.

On top of that, I’ve tried it on other PCs with a couple of different recent drivers:

  • My other Windows 10 machine, GTX 960 (Driver 456.71 and 436.48), 48GB Corsair memory, Ryzen 1300
  • Friend’s Windows 10 machine, GTX 960 (456.71), 16GB RAM, i5-3570k
  • Friend’s Windows 10 machine, GTX 980TI (456.71), 32GB RAM, Ryzen 2700X

It DID work fine on my other machine when I tried it on my Ubuntu 16 boot.

Excuse me, why does this error appear? Is it because of the type of operating system?