Sudden exit on relatively small model (GPU)

This is one of my first times using PyTorch, so it’s likely something small and stupid.
I have a very small model with two Conv1d layers. When I make their filter length small enough, the model trains; when I increase it, the model fails.
I have a GTX 1070 8GB GPU on which I have run models an order of magnitude larger with the same data (in Keras).
This is the error I get:

Process finished with exit code -1073740940 (0xC0000374)

Here is a dummy test that reproduces the issue:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np

class Net(nn.Module):
    def __init__(self, in_channels, out_channels, filter_len):
        super(Net, self).__init__()
        self.conv_lin = nn.Conv1d(in_channels, out_channels, filter_len, padding=filter_len - 1)
        nn.init.xavier_uniform(self.conv_lin.weight)
        self.conv_sig = nn.Conv1d(in_channels, out_channels, filter_len, padding=filter_len - 1)
        nn.init.xavier_uniform(self.conv_sig.weight)

    def forward(self, x):
        x_lin = self.conv_lin(x)
        x_sig = torch.sigmoid(self.conv_sig(x))
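        # gated activation: the linear branch is modulated element-wise by a sigmoid gate,
        # then log_softmax over the channel dimension produces CTC-ready log-probabilities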
        return F.log_softmax(x_lin * x_sig, dim=1)



# Define train and test data
batch_size = 32
train_loader = None  # Change this to training data iterator
test_loader = None  # Change this to testing data iterator


# Checking GPU availability
use_gpu = True #torch.cuda.is_available()


model = Net(13, 25, 40) # Process finished with exit code -1073740940 (0xC0000374)
model = Net(13, 25, 9) # works
if use_gpu:
    model = model.cuda()

optimizer = torch.optim.SGD(model.parameters(), 0.01, 0.9)



model.train()
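# 25 output channels cover label classes 0-23 plus the CTC blank at index 24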
criterion = nn.CTCLoss(blank=24, reduction='mean')
for batch_idx in range(100):

    mfcc = Variable(torch.from_numpy(np.random.randn(5, 13, 200)).cuda())
    input_lengths = Variable(torch.from_numpy(np.random.randint(100, 200, size=(mfcc.shape[0]))).cuda())
    output_lengths = Variable(torch.from_numpy(np.random.randint(50, 75, size=(mfcc.shape[0]))).cuda())
    labels = Variable(torch.from_numpy(np.random.randint(0, 24, size=(output_lengths.sum()))).cuda())


    optimizer.zero_grad()
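    # nn.CTCLoss expects log-probabilities of shape (T, N, C);
    # the two transposes turn the conv output (N, C, T) into that layout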
    output = model(mfcc.float()).transpose(1, 2).transpose(0, 1)
    loss = criterion(output, labels, input_lengths, output_lengths)

    loss.backward()
    print("hi")
    optimizer.step()

Does anyone have an idea why this is failing?

Do you get any kind of error when it “dies” after two batches?
Since the model seems to be quite small, I doubt you are running out of memory (assuming, of course, the GPU isn’t already filled by other processes).
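If you want to rule that out, a quick sketch using torch.cuda.memory_allocated / max_memory_allocated will show how much memory PyTorch has actually grabbed on the device:

import torch

print(torch.cuda.memory_allocated() / 1024**2, "MB currently allocated")
print(torch.cuda.max_memory_allocated() / 1024**2, "MB peak allocation")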

I edited my post with example code.
The error I get is:
Process finished with exit code -1073740940 (0xC0000374)
I sometimes get something similar when I hit an OOM in Keras, which leads me to believe it’s something to do with memory…

0xC0000374 is most likely a heap corruption error on Windows.
Could you run your code with a debugger and inspect the stack trace?

Thank you for the incredibly fast reply.
I am running it in debug; it exits abruptly. I have narrowed it down to

loss.backward()

Oh, I hadn’t seen that you added the code.
I’ll try to reproduce it on my machine.

EDIT:
I can reproduce the error with the exception: free(): invalid next size (normal) in Variable._execution_engine.run_backward().

EDIT2:
The code runs fine on CPU, so it seems to be related to some CUDA functions.

I wonder why I don’t get that debug output… Are you on Linux?

I forgot to mention that I started with CPU only and everything ran fine, but very, very slowly…

Is this something meaningful?
free(): invalid next size (normal) in Variable._execution_engine.run_backward()

I’m not sure and am digging a bit further. :wink:
Yes, I’m on Ubuntu 18.04, so probably that’s why the error messages are different.

Thank you!
While I have your attention: I saw that there are .cuda() and .to(device). What is the best practice?

I would use to(device), since it makes it easier to write device-agnostic code.
In fact, I removed all the cuda() calls from your code on my machine and replaced them with:

device = 'cpu'
#device = 'cuda:0'

model = model.to(device)

to easily switch between CPU and GPU runs.
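A slightly fuller sketch of the same idea (assuming the Net class from above; the names are just illustrative):

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Net(13, 25, 9).to(device)              # parameters and buffers follow the device
mfcc = torch.randn(5, 13, 200, device=device)  # create inputs directly on the same device
output = model(mfcc)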

Here is the backtrace of gdb:

#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:51
#1  0x00007ffff7805801 in __GI_abort () at abort.c:79
#2  0x00007ffff784e897 in __libc_message (action=action@entry=do_abort, 
    fmt=fmt@entry=0x7ffff797bb9a "%s\n") at ../sysdeps/posix/libc_fatal.c:181
#3  0x00007ffff785590a in malloc_printerr (
    str=str@entry=0x7ffff797d8b8 "free(): invalid next size (normal)") at malloc.c:5350
#4  0x00007ffff785d0ad in _int_free (have_lock=0, p=0x7fff58005310, av=0x7fff58000020)
    at malloc.c:4286
#5  __GI___libc_free (mem=0x7fff58005320) at malloc.c:3124
#6  0x00007fffa9cd5b72 in cudnn::maxwell::gemm::conv2d(cudnnContext*, void const*, cudnnTensor4dStruct*, void const*, cudnnFilter4dStruct*, void const*, cudnnConvolutionStruct*, cudnnConvWorkingStruct const*, void*, unsigned long, void const*, cudnnTensor4dStruct*, void*, cudnn::maxwell::gemm::Conv2dType_t, cudnn::maxwell::gemm::Conv2dConfig&, bool, void const*, cudnnActivationStruct*, void*) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#7  0x00007fffa9d4ad0b in cudnn::wgrad2d::invokeFfmaKernel(cudnnContext*, void const*, cudnnTensor4dStruct*, void const*, cudnnTensor4dStruct*, void const*, cudnnConvolutionStruct*, cudnnConvWorkingStruct const*, cudnnConvolutionBwdFilterAlgo_t, void*, unsigned long, void const*, cudnnFilter4dStruct*, void*, cudnnStatus_t*) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#8  0x00007fffa9d4f96e in cudnnConvolution4dBackwardFilter(cudnnContext*, void const*, cudnnTensor4dStruct*, void const*, cudnnTensor4dStruct*, void const*, cudnnConvolutionStruct*, cudnnConvWorkingStruct const*, cudnnConvolutionBwdFilterAlgo_t, void*, unsigned long, void const*, cudnnFilter4dStruct*, void*) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#9  0x00007fffa9a2435c in cudnnConvolutionBackwardFilterInternal(cudnnContext*, void const*, cudnnTensorStruct*, void const*, cudnnTensorStruct*, void const*, cudnnConvolutionStruct*, cudnnConvolutionBwdFilterAlgo_t, void*, unsigned long, void const*, cudnnFilterStruct*, void*) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#10 0x00007fffa9a24b08 in cudnnConvolutionBackwardFilter ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#11 0x00007fffa7306da4 in at::native::raw_cudnn_convolution_backward_weight_out(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#12 0x00007fffa73074f7 in at::native::cudnn_convolution_backward_weight(char const*, c10::ArrayRef<long>, at::TensorArg const&, at::TensorArg const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#13 0x00007fffa7307827 in at::native::cudnn_convolution_backward_weight(c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#14 0x00007fffa73dc66b in at::CUDAFloatType::cudnn_convolution_backward_weight(c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) const ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#15 0x00007fffa73039f2 in at::native::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#16 0x00007fffa73dc832 in at::CUDAFloatType::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so
#17 0x00007fffa0e1966a in torch::autograd::VariableType::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libtorch.so.1
#18 0x00007fffa0c52b66 in torch::autograd::generated::CudnnConvolutionBackward::apply(std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> >&&) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libtorch.so.1
#19 0x00007fffa0c2869e in torch::autograd::Engine::evaluate_function(torch::autograd::FunctionTask&) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libtorch.so.1
#20 0x00007fffa0c2a770 in torch::autograd::Engine::thread_main(torch::autograd::GraphTask*) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libtorch.so.1
#21 0x00007fffa0c27222 in torch::autograd::Engine::thread_init(int) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libtorch.so.1
#22 0x00007fffe67c54ca in torch::autograd::python::PythonEngine::thread_init(int) ()
   from /home/ptrblck/anaconda3/envs/pytorch_latest/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#23 0x00007fffe7534678 in std::execute_native_thread_routine_compat (__p=<optimized out>)
    at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/src/c++11/thread.cc:94
#24 0x00007ffff7bbd6db in start_thread (arg=0x7fff90ffd700) at pthread_create.c:463
#25 0x00007ffff78e688f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

I’m not sure what’s going on and would like to invite some experts on this topic.
CC @colesbury, @albanD
Any ideas how to debug this issue further?

EDIT:
It looks like this issue is related to cuDNN.
@Dan_Erez
You could set torch.backends.cudnn.enabled = False at the beginning of your script for now to avoid this problem.
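For example (a minimal sketch of the workaround; it simply forces PyTorch to fall back to its native convolution kernels instead of cuDNN):

import torch

# disable cuDNN before the model is built / the first forward pass runs
torch.backends.cudnn.enabled = False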


Hmmm… I’m not sure what the cause is. Thanks for debugging it @ptrblck. I filed an issue here: https://github.com/pytorch/pytorch/issues/17060


I figured it out!
It’s the padding. I realized that what I actually wanted was half the filter length of padding on each side. It now runs fine (even for a larger network).

class Net(nn.Module):
    def __init__(self, in_channels, out_channels, filter_len):
        super(Net, self).__init__()
        self.conv_lin = nn.Conv1d(in_channels, out_channels, filter_len, padding=(filter_len - 1)//2)
        nn.init.xavier_uniform(self.conv_lin.weight)
        self.conv_sig = nn.Conv1d(in_channels, out_channels, filter_len, padding=(filter_len - 1)//2)
        nn.init.xavier_uniform(self.conv_sig.weight)

    def forward(self, x):
        x_lin = self.conv_lin(x)
        x_sig = torch.sigmoid(self.conv_sig(x))
        return x_lin * x_sig

I guess there is some bug triggered by unexpectedly large padding…
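Just to spell out why the padding matters for the sequence length, here is a quick check with the standard Conv1d output-length formula (stride=1, dilation=1): L_out = L_in + 2*padding - (filter_len - 1).

filter_len, L_in = 40, 200
print(L_in + 2 * (filter_len - 1) - (filter_len - 1))         # padding=filter_len-1      -> 239 (longer than the input)
print(L_in + 2 * ((filter_len - 1) // 2) - (filter_len - 1))  # padding=(filter_len-1)//2 -> 199 (roughly "same" length)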

I had noticed the unusual padding and tried to reproduce the error with a “standalone” CNN using just your conv layers, but failed to do so.

Good that it’s working now, but something is still going on in your original architecture.

Yeah, there’s a bug somewhere - it could be a problem for causal Conv1d.

best of luck
