Segfault during backward when using PReLU

I’m trying to train a network that uses a PReLU module, but I get a segfault during the backward pass. Here’s a piece of code that reproduces the bug:

import torch
import torch.nn as nn

gt = torch.rand(2, 3, 256, 256)
gt = torch.autograd.Variable(gt.cuda(non_blocking=True))
input = torch.rand(2, 134, 256, 256)
input = torch.autograd.Variable(input.cuda())  # neither input nor gt requires gradients
lossL1 = torch.nn.L1Loss()
lossL1 = lossL1.cuda()

net = nn.Sequential(nn.PReLU(), nn.Conv2d(134, 3, kernel_size=1, stride=1, bias=False)).cuda()

output = net(input)

loss = lossL1(output, gt)
loss.backward()  # segfaults here

In this example, my network just consists of a PReLU followed by a simple convolution. Note that if I switch the order of the two modules, the segfault doesn’t occur (see the sketch below), so the bug only shows up when the PReLU is the first layer.

Also note that if I don’t use the GPU, the segfault doesn’t occur either.
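
For reference, here is a minimal sketch of the reordering workaround mentioned above, using the same shapes and modules as my snippet but with the Conv2d moved in front of the PReLU; this version backpropagates without crashing:

import torch
import torch.nn as nn

gt = torch.rand(2, 3, 256, 256).cuda()
input = torch.rand(2, 134, 256, 256).cuda()

# Conv2d first, PReLU second: the PReLU is no longer the first layer
net = nn.Sequential(nn.Conv2d(134, 3, kernel_size=1, stride=1, bias=False), nn.PReLU()).cuda()

loss = nn.L1Loss()(net(input), gt)
loss.backward()  # completes without a segfault in this ordering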

Note: I tried with PyTorch versions 0.4.0 and 0.4.1.

I could also reproduce it on 0.5.0a0+2c7c12f.

Here is the backtrace:

#0 0x00007fffd42ce067 in THCTensor_nElement () from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libcaffe2_gpu.so
#1 0x00007fffd3cbaefa in bool THC_pointwiseApply3<float, float, float, THTensor, THTensor, THTensor, PReLUAccGradParametersShared >(THCState*, THTensor*, THTensor*, THTensor*, PReLUAccGradParametersShared const&, TensorArgType, TensorArgType, TensorArgType) ()
from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libcaffe2_gpu.so
#2 0x00007fffd3c99692 in THNN_CudaPReLU_accGradParameters () from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libcaffe2_gpu.so
#3 0x00007fffd41f2018 in at::CUDAFloatType::prelu_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, std::array<bool, 2ul>) const ()
from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libcaffe2_gpu.so
#4 0x00007fffd1f47b7d in torch::autograd::VariableType::prelu_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, std::array<bool, 2ul>) const ()
from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libtorch.so.1
#5 0x00007fffd1e4fa93 in torch::autograd::generated::PreluBackward::apply(std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> >&&) ()
from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libtorch.so.1
#6 0x00007fffd1e0bfcb in torch::autograd::Function::operator()(std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> >&&) ()
from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libtorch.so.1
#7 0x00007fffd1e07291 in torch::autograd::Engine::evaluate_function(torch::autograd::FunctionTask&) ()
from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libtorch.so.1
#8 0x00007fffd1e07d8b in torch::autograd::Engine::thread_main(torch::autograd::GraphTask*) () from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libtorch.so.1
#9 0x00007fffd1e045b4 in torch::autograd::Engine::thread_init(int) () from /home/pbialecki/libs/ptrblck/pytorch/torch/lib/libtorch.so.1
#10 0x00007fffe41c5a2a in torch::autograd::python::PythonEngine::thread_init (this=0x7fffe4a88200 , device=0) at torch/csrc/autograd/python_engine.cpp:39
#11 0x00007fffd1651c5c in std::execute_native_thread_routine_compat (__p=)
at /opt/conda/conda-bld/compilers_linux-64_1520532893746/work/.build/src/gcc-7.2.0/libstdc++-v3/src/c++11/thread.cc:110
#12 0x00007ffff7bc16ba in start_thread (arg=0x7fff8df82700) at pthread_create.c:333
#13 0x00007ffff78f741d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:109

I tried to debug it a bit, and it seems the segfault occurs if the input does not require gradients.
Code to reproduce the bug:

import torch
import torch.nn as nn

act = nn.PReLU().to('cuda')
x = torch.randn(1, requires_grad=False, device='cuda')  # input does not require gradients
output = act(x)
output.mean().backward()  # segfaults here
print(x.grad)

Setting requires_grad=True for x works.
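
For completeness, a minimal sketch of that workaround; the only change compared to the snippet above is requires_grad=True on the input:

import torch
import torch.nn as nn

act = nn.PReLU().to('cuda')
x = torch.randn(1, requires_grad=True, device='cuda')  # input now requires gradients
output = act(x)
output.mean().backward()   # runs without a segfault
print(x.grad)              # gradient w.r.t. the input
print(act.weight.grad)     # gradient w.r.t. PReLU's learnable slope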

@tommm994 Could you open a GitHub issue and link to this thread?
If you are busy, let me know and I can do it.


OK, thanks @ptrblck! I had already opened an issue on GitHub. I’ve now linked it to this thread.

By the way, with PyTorch 0.3.1, the bug does not occur.