cuBLAS runtime error with output_padding in ConvTranspose3d

I have a model with the following layers:


        self.conv1 = nn.Conv3d(1, 3, 3, padding=1)
        self.conv2 = nn.Conv3d(3, 128, 3, padding=1, stride=2)
        self.conv3 = nn.Conv3d(128, 128, 3, padding=1, stride=2)
        self.conv4 = nn.Conv3d(128, 256, 3, padding=1, stride=1)
        self.conv5 = nn.Conv3d(256, 512, 3, padding=1, stride=2)
        self.conv6 = nn.Conv3d(512, 1024, 3, padding=1, stride=1)
        self.conv7 = nn.Conv3d(1024, 1024, 3, padding=1, stride=1)

        self.deconv1 = nn.ConvTranspose3d(1024, 1024, 3, 1, padding=1, output_padding=0)
        self.deconv2 = nn.ConvTranspose3d(1024, 512, 3, 2, padding=1, output_padding=0)
        self.deconv3 = nn.ConvTranspose3d(512, 256, 3, 2, padding=1, output_padding=0)
        self.deconv4 = nn.ConvTranspose3d(256, 128, 3, 1, padding=1, output_padding=0)
        self.deconv5 = nn.ConvTranspose3d(128, 128, 3, 2, padding=1, output_padding=1)
        self.deconv6 = nn.ConvTranspose3d(128, 2, 3, 1, padding=1, output_padding=0)

When a tensor is passed through the model at runtime, the code produces the following error with the cuDNN backend if output_padding is set to 1 in deconv5:
RuntimeError: cublas runtime error : library not initialized at /pytorch/torch/lib/THC/THCGeneral.c:405.
Can someone please explain what might be causing this error?
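As far as I understand, output_padding only changes the computed output size; a quick sanity sketch of the ConvTranspose3d size formula from the docs (assuming dilation 1):

    def deconv_out(n, kernel=3, stride=2, padding=1, output_padding=0):
        # ConvTranspose3d output size per the docs (dilation = 1)
        return (n - 1) * stride - 2 * padding + kernel + output_padding

So output_padding=1 should only make deconv5's output one voxel larger per spatial dimension, which is why the crash surprises me.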

Can you please try running your code with the env variable CUDA_LAUNCH_BLOCKING=1, or with cuda-memcheck (cuda-memcheck python my_code.py)? Warning: to run with cuda-memcheck you have to trim your code down to a bare minimum, otherwise it will be very slow. The cublas error is likely happening because an earlier cudnn call returned an error, but it's hard to say what is going on without additional debug info.
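If it is more convenient, CUDA_LAUNCH_BLOCKING can also be set from inside the script instead of the shell, as long as it happens before the first CUDA call; a minimal sketch:

    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before any CUDA work happens

    import torch  # import torch (and touch the GPU) only after setting the variable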

cuda-memcheck returned 128 errors; for some reason it reported an out-of-bounds write (full output below). However, no error is encountered when the cuDNN backend is disabled.
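For reference, disabling the cuDNN backend was done with the standard global switch, roughly:

    import torch
    torch.backends.cudnn.enabled = False  # fall back to the non-cuDNN convolution kernels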

========= CUDA-MEMCHECK
========= Invalid __global__ write of size 8
=========     at 0x00003710 in void fft3d_r2c_16x16x16<float, float, float2>(float2*, float*, int3, int3, int3, int3, int3, bool)
=========     by thread (15,11,0) in block (235077,0,0)
=========     Address 0x7fbd57036df8 is out of bounds
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame:/usr/lib/x86_64-linux-gnu/libcuda.so.1 (cuLaunchKernel + 0x2cd) [0x22b12d]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 [0xcea0eb]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 [0xd0733e]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 [0x9e4393]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 [0x9e5968]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 [0x9c7233]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 [0x72fa85]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 [0x74a01]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/lib/libcudnn-7a90c013.so.7.0.5 (cudnnConvolutionBackwardData + 0x46c) [0x74eec]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so (_ZN5torch5cudnn31cudnn_convolution_backward_dataEP8THCStateP12cudnnContext15cudnnDataType_tPNS_12THVoidTensorES7_S7_PNS0_11ConvolutionEbb + 0x80e) [0x11ad11e]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so (_ZN5torch5cudnn40cudnn_convolution_transpose_full_forwardEP8THCStateP12cudnnContext15cudnnDataType_tPNS_12THVoidTensorES7_S7_S7_St6vectorIiSaIiEESA_SA_ibb + 0x321) [0x11b19c1]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so (_ZN5torch8autograd11ConvForward5applyERKSt6vectorINS0_8VariableESaIS3_EE + 0x1039) [0x522619]
=========     Host Frame:/home/thor/vp36sys/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so [0x410cde]
=========     Host Frame:python (_PyObject_FastCallKeywords + 0x10b) [0xe2a2b]
=========     Host Frame:python [0x178b75]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x3da) [0x17130a]
=========     Host Frame:python [0x17047f]
=========     Host Frame:python [0x179abb]
=========     Host Frame:python [0x178acc]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x3da) [0x17130a]
=========     Host Frame:python [0x17047f]
=========     Host Frame:python (_PyFunction_FastCallDict + 0x440) [0x17a6f0]
=========     Host Frame:python (_PyObject_Call_Prepend + 0x24c) [0xe383c]
=========     Host Frame:python (PyObject_Call + 0x3a) [0xe2f0a]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x1ab5) [0x1729e5]
=========     Host Frame:python [0x17047f]
=========     Host Frame:python (_PyFunction_FastCallDict + 0x1da) [0x17a48a]
=========     Host Frame:python (_PyObject_Call_Prepend + 0x24c) [0xe383c]
=========     Host Frame:python (PyObject_Call + 0x3a) [0xe2f0a]
=========     Host Frame:python [0x1333ee]
=========     Host Frame:python (_PyObject_FastCallKeywords + 0x10b) [0xe2a2b]
=========     Host Frame:python [0x178b75]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x3da) [0x17130a]
=========     Host Frame:python (_PyFunction_FastCallDict + 0x133) [0x17a3e3]
=========     Host Frame:python (_PyObject_Call_Prepend + 0x24c) [0xe383c]
=========     Host Frame:python (PyObject_Call + 0x3a) [0xe2f0a]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x1ab5) [0x1729e5]
=========     Host Frame:python [0x17047f]
=========     Host Frame:python (_PyFunction_FastCallDict + 0x1da) [0x17a48a]
=========     Host Frame:python (_PyObject_Call_Prepend + 0x24c) [0xe383c]
=========     Host Frame:python (PyObject_Call + 0x3a) [0xe2f0a]
=========     Host Frame:python [0x1333ee]
=========     Host Frame:python (_PyObject_FastCallKeywords + 0x10b) [0xe2a2b]
=========     Host Frame:python [0x178b75]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x3da) [0x17130a]
=========     Host Frame:python [0x1799fd]
=========     Host Frame:python [0x178acc]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x3da) [0x17130a]
=========     Host Frame:python [0x1799fd]
=========     Host Frame:python [0x178acc]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x3da) [0x17130a]
=========

Could it perhaps be that more threads are launched than there are elements, and the thread indices are not checked for out-of-bounds access?

This looks like a cudnn bug. Can you please isolate which convolution parameters trigger it (input size, kernel size, padding), or provide a small repro script that triggers it? Also, which cudnn version are you using? Thank you!

For the following input, the error is generated in the above network at deconv5 at runtime:

    import numpy as np
    import torch
    from torch.autograd import Variable  # 0.3-style autograd wrapper

    arr = np.random.randn(1, 64, 64, 64)
    input = Variable(torch.from_numpy(arr).float().unsqueeze(1).cuda())
    print(input.shape)
    net = ColorConv().cuda()  # the model defined above
    output = net(input)
    print(output.shape)
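To help isolate the parameters: tracing the spatial sizes with the standard size formulas (my own arithmetic, so worth double-checking), deconv5 should be seeing a 1 x 128 x 29 x 29 x 29 input and producing a 58 x 58 x 58 output when output_padding=1:

    def conv_out(n, k=3, s=1, p=1):
        # Conv3d output size (dilation = 1)
        return (n + 2 * p - k) // s + 1

    def deconv_out(n, k=3, s=1, p=1, op=0):
        # ConvTranspose3d output size (dilation = 1)
        return (n - 1) * s - 2 * p + k + op

    n = 64
    for s in [1, 2, 2, 1, 2, 1, 1]:                                 # conv1..conv7 strides
        n = conv_out(n, s=s)
    print(n)                                                        # 8 after conv7
    for s, op in [(1, 0), (2, 0), (2, 0), (1, 0), (2, 1), (1, 0)]:  # deconv1..deconv6
        n = deconv_out(n, s=s, op=op)
        print(n)                                                    # 8, 15, 29, 29, 58, 58

With output_padding=0 that last stride-2 layer would produce 57 instead of 58, and that is the only difference from the case that works.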

What card and cudnn version are you using? I could not repro it on a P100 or a V100 with the following script:

import torch
import torch.nn as nn
import numpy as np

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv3d(1, 3, 3, padding=1)
        self.conv2 = nn.Conv3d(3, 128, 3, padding=1, stride=2)
        self.conv3 = nn.Conv3d(128, 128, 3, padding=1, stride=2)
        self.conv4 = nn.Conv3d(128, 256, 3, padding=1, stride=1)
        self.conv5 = nn.Conv3d(256, 512, 3, padding=1, stride=2)
        self.conv6 = nn.Conv3d(512, 1024, 3, padding=1, stride=1)
        self.conv7 = nn.Conv3d(1024, 1024, 3, padding=1, stride=1)

        self.deconv1 = nn.ConvTranspose3d(1024, 1024, 3, 1, padding=1, output_padding=0)
        self.deconv2 = nn.ConvTranspose3d(1024, 512, 3, 2, padding=1, output_padding=0)
        self.deconv3 = nn.ConvTranspose3d(512, 256, 3, 2, padding=1, output_padding=0)
        self.deconv4 = nn.ConvTranspose3d(256, 128, 3, 1, padding=1, output_padding=0)
        self.deconv5 = nn.ConvTranspose3d(128, 128, 3, 2, padding=1, output_padding=1)
        self.deconv6 = nn.ConvTranspose3d(128, 2, 3, 1, padding=1, output_padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        print(x.size())
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        x = self.deconv5(x)
        x = self.deconv6(x)
        return x



model = MyModel().cuda()
arr = np.random.randn(1, 64, 64, 64)
input = (torch.from_numpy(arr).float().unsqueeze(1).cuda())
print(input.shape)
output = model(input)
torch.cuda.synchronize()
print(output.shape)
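Also, to make sure we are comparing the same setup, could you post the output of something like the following? It prints the versions PyTorch actually sees (assuming a single visible GPU):

    import torch
    print(torch.__version__)                # PyTorch version
    print(torch.version.cuda)               # CUDA version PyTorch was built against
    print(torch.backends.cudnn.version())   # cuDNN version in use
    print(torch.cuda.get_device_name(0))    # GPU name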

GTX 1070, cuDNN 7101. Another key point is that I encounter this error in PyTorch 0.3.1 with your script after passing the input tensor wrapped in autograd.Variable.
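The only change I made to your script was wrapping the input, roughly as in my earlier snippet:

    from torch.autograd import Variable

    input = Variable(torch.from_numpy(arr).float().unsqueeze(1).cuda())
    output = model(input)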