Illegal memory access in backward after first training epoch

I’m using a DataLoader and looping through my training data in the usual way:

for epoch in range(num_epochs):
    for training_batch_idx, training_batch in enumerate(dataloader):
        # forward/backward propagation code

Everything is fine during the first epoch. In the second epoch, when backward() is called for the first time, I get the following error:

THCudaCheck FAIL file=/data/users/soumith/miniconda2/conda-bld/pytorch-cuda80-0.1.10_1488757768560/work/torch/lib/THCUNN/generic/PReLU.cu line=79 error=77 : an illegal memory access was encountered
Traceback (most recent call last):
  File "trainer.py", line 115, in <module>
    err.backward()
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/nn/_functions/thnn/activation.py", line 53, in backward
    1
RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /data/users/soumith/miniconda2/conda-bld/pytorch-cuda80-0.1.10_1488757768560/work/torch/lib/THCUNN/generic/PReLU.cu:79

The error points to some PReLU code. However, if I replace all the PReLU layers in my net with ReLU, I still get an illegal memory access error; it just points somewhere else:

THCudaCheck FAIL file=/data/users/soumith/miniconda2/conda-bld/pytorch-cuda80-0.1.10_1488757768560/work/torch/lib/THC/generic/THCTensorMath.cu line=26 error=77 : an illegal memory access was encountered
Traceback (most recent call last):
  File "trainer.py", line 115, in <module>
    err.backward()
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/nn/_functions/batchnorm.py", line 60, in backward
    grad_bias = bias.new(bias.size()).zero_()
RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /data/users/soumith/miniconda2/conda-bld/pytorch-cuda80-0.1.10_1488757768560/work/torch/lib/THC/generic/THCTensorMath.cu:26

Any thoughts on what might cause an error like this?

Is there any chance we could see the model?

Below is a minimal example that produces an illegal memory access error. The really frustrating thing is that seemingly unimportant modifications to the network make the error go away: changing the number of planes in the hidden layer from 8 to 16 removes it, and so does deleting the ReLU layer.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data

#---------------------------------------------------------------------------
# define encoder-decoder network and optimizer

class encoder_decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1,8, kernel_size=2, stride=2)
        self.deconv = nn.ConvTranspose3d(8,1, kernel_size=3, stride=2, output_padding=-1)

    def forward(self, net_input):
        out = net_input
        out = self.conv(out)
        out = self.deconv(out)
        out = nn.ReLU()(out)
        return out

net = encoder_decoder()
net.cuda()

criterion = nn.MSELoss()
criterion.cuda()

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.5)

#----------------------------------------------------------------------------
# define dataset and dataloader

class create_dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [ (torch.rand(1,64,64,64),
                       torch.rand(1,64,64,64)) for i in range(100) ]

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset = create_dataset()
print('Loaded ' + str(len(dataset)) + ' training examples')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)

#----------------------------------------------------------------------------
# training loop

print('Start training loop')
for epoch in range(4):
    print('Epoch: ' + str(epoch))
    net.train()

    for training_idx, (input_batch,target_batch) in enumerate(dataloader):
        print('Training batch: ' + str(training_idx))
        input_batch = Variable(input_batch.cuda())
        target_batch = Variable(target_batch.cuda())

        optimizer.zero_grad()
        output_batch = net(input_batch)
        err = criterion(output_batch, target_batch)
        err.backward()
        optimizer.step()

The real stacktrace (using CUDA_LAUNCH_BLOCKING=1) is:

Loaded 100 training examples
Start training loop
Epoch: 0
Training batch: 0
Traceback (most recent call last):
  File "tmp.py", line 70, in <module>
    err.backward()
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/_functions/conv.py", line 48, in backward
    if self.needs_input_grad[0] else None)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/_functions/conv.py", line 112, in _grad_input
    cudnn.benchmark)
RuntimeError: CUDNN_STATUS_EXECUTION_FAILED

Temp fix: adding torch.backends.cudnn.enabled = False at the beginning of your script makes it work as a temporary workaround.
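A minimal sketch of where the flag goes (nothing else in the script needs to change):

import torch

# Disable cudnn globally; convolutions fall back to the slower native kernels
torch.backends.cudnn.enabled = False

# ... model, dataloader and training loop exactly as before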

Real problem: why does this cudnn call fail? I don’t know :confused:
EDIT: this exact code that fails on a Titan X runs properly on a Titan Black.


I’ve tested the minimal example on three GPUs: Maxwell Titan X, GTX 1080, and Pascal Titan X. It fails on all three. The 2d version of this encoder-decoder works just fine, but it would be nice to get it working for the 3d case as well. Note that this is a very common network construction (minus the hidden layers), so anyone segmenting 3d data is going to run into it.

For the time being, I’ll just disable cudnn, as suggested.

I’ve reproduced this with cudnn 5.1.10 (CUDNN_STATUS_EXECUTION_FAILED). On cudnn 6.0.5 it returns:

Traceback (most recent call last):
  File "conv3d.py", line 64, in <module>
    output_batch = net(input_batch)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 202, in __call__
    result = self.forward(*input, **kwargs)
  File "conv3d.py", line 19, in forward
    out = self.deconv(out)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 202, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/conv.py", line 613, in forward
    output_padding, self.groups)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/functional.py", line 141, in conv_transpose3d
    return f(input, weight, bias)
RuntimeError: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.

@albanD, I take it the input is already contiguous, so it must be something else.

Interesting. When I call backward on a 3D convolution that is supposed to perform the same operation as the ConvTranspose, it is OK, but the ConvTranspose itself errors out, even though it should be the same parameters to the same cudnnConvolutionBackwardData call.

import torch
import torch.nn as nn
from torch.autograd import Variable

conv = nn.Conv3d(1,8, kernel_size=3, padding=1, stride = 2)
conv = conv.cuda()
deconv = nn.ConvTranspose3d(8,1, kernel_size=3, stride=2, output_padding=-1)
deconv = deconv.cuda()

x = Variable(torch.randn(8,1,64,64,64).cuda(), requires_grad=True)
# OK: backward through the regular 3d convolution works
for i in range(10):
    out = conv(x)
    err = out.sum()
    err.backward()

print("deconv")

# error: the transposed convolution (same underlying cudnn call) fails
x = Variable(torch.randn(8,8,32,32,32).cuda(), requires_grad=True)

for i in range(10):
    out = deconv(x)
    err = out.sum()
    err.backward()

Looks like it’s a bug in pytorch that ignores the output_padding argument (it should be translated into padding for the convolution descriptor). @colesbury ^^. The convolution descriptor that is passed into the failing cudnnGetConvolutionBackwardDataAlgorithm call has 0 padding, whereas it should have 1. Here’s the ltrace for the failing call, followed by the ltrace for the passing one (the gradient w.r.t. data for the companion forward conv); they should be the same:

[pid 21910] _C.cpython-35m-x86_64-linux-gnu.so->cudnnGetConvolutionBackwardDataAlgorithm(0xffff7fffffff, { CUDNN_DATA_FLOAT, 0, [ 5, 8, 1, 3... ], CUDNN_TENSOR_NCHW }, { CUDNN_DATA_FLOAT, 0, 5, [ 2097152, 0, 8, 8... ], [ 0, 0, 262144, 32768... ] }, { CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT, 3, [ 0, 0, 0, 0... ], [ 2, 2, 2, 0... ], [ 1, 1, 1, 0... ] }, { CUDNN_DATA_FLOAT, 0, 5, [ 2097152, 0, 8, 1... ], [ 0, 0, 262144, 262144... ] }, 1, 0, 0 <unfinished ...>
[pid 21917] _C.cpython-35m-x86_64-linux-gnu.so->cudnnGetConvolutionBackwardDataAlgorithm(0xffff7fffffff, { CUDNN_DATA_FLOAT, 0, [ 5, 8, 1, 3... ], CUDNN_TENSOR_NCHW }, { CUDNN_DATA_FLOAT, 0, 5, [ 2097152, 0, 8, 8... ], [ 0, 0, 262144, 32768... ] }, { CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT, 3, [ 1, 1, 1, 0... ], [ 2, 2, 2, 0... ], [ 1, 1, 1, 0... ] }, { CUDNN_DATA_FLOAT, 0, 5, [ 2097152, 0, 8, 1... ], [ 0, 0, 262144, 262144... ] }, 1, 0, 0 <unfinished ...>

It is still a mystery to me how it can work anywhere (on Kepler? with 16 planes? with the ReLU deleted?). It would also probably be a good idea to add a deconv test with output_padding; there are probably none right now (something like the sketch below).
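A rough sketch of what such a test could look like, assuming a reasonably recent torch.autograd.gradcheck, with small double-precision tensors on the CPU (for the cudnn path they would have to be moved to the GPU). It exercises a positive output_padding, since current versions require 0 <= output_padding < stride, so the negative case from the original report can no longer be constructed this way:

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

# Small double-precision tensors; gradcheck compares analytic vs. numeric gradients.
x = torch.randn(1, 2, 4, 4, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(2, 1, 3, 3, 3, dtype=torch.double, requires_grad=True)

def deconv(inp, weight):
    # 3d transposed convolution that actually uses output_padding
    return F.conv_transpose3d(inp, weight, stride=2, padding=1, output_padding=1)

print(gradcheck(deconv, (x, w)))  # True if the gradients agree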


To close the loop on this, here’s (my best idea of) what’s happening. Until v6, cudnn did not strictly check the dimensions of the output tensor passed to the convolution routines. Output dimensions were supposed to be calculated with the provided helper function; passing output dimensions that differ from the expected ones could result in undefined behavior, which is what you were seeing: sometimes it would appear to run normally (probably producing garbage), sometimes it would result in illegal memory accesses. In cudnn v6, if the passed output dimensions are bigger than what cudnn expects, it returns STATUS_NOT_SUPPORTED, and it respects user-provided output dimensions as long as they are not bigger than expected. When negative output_padding is used, the output dimensions passed end up bigger than what cudnn expects. I’ve disabled the use of cudnn for negative output_padding in https://github.com/pytorch/pytorch/pull/996, but the gradchecker tests still don’t pass for negative output_padding, and I don’t exactly know what the pytorch backends compute in this case.
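To make the size bookkeeping concrete, here is a small sketch using the ConvTranspose3d output-size formula from the docs, plugged with the numbers from the minimal example above (64^3 input, encoder Conv3d with kernel_size=2/stride=2, decoder with kernel_size=3/stride=2/output_padding=-1); the helper name is made up:

# Per spatial dimension, with dilation = 1:
#   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
def deconv_out_size(inp, stride, padding, kernel_size, output_padding):
    return (inp - 1) * stride - 2 * padding + kernel_size + output_padding

# The encoder halves 64 -> 32; the decoder is asked to map 32 back to 64:
print(deconv_out_size(32, stride=2, padding=0, kernel_size=3, output_padding=-1))  # 64
# With the output_padding ignored (the zero-padding descriptor seen in the ltrace),
# the size consistent with the descriptor would instead be:
print(deconv_out_size(32, stride=2, padding=0, kernel_size=3, output_padding=0))   # 65

For what it’s worth, a ConvTranspose3d with kernel_size=2 and stride=2 (and no output_padding) also maps 32 back to 64, which sidesteps the negative output_padding entirely, at the cost of a different kernel.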


Is the problem solved? I am getting the same error when using Conv3D

RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at …/pytorch/torch/lib/THCUNN/generic/Threshold.cu:66

After updating PyTorch I am still getting:

RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at …/pytorch/torch/lib/THCUNN/generic/Threshold.cu:66

For me this error pops up not in the first epoch but in the fourth.


I’ve just seen this too. A model that had been working gives an error after 17 epochs:

  File "/home/jhoward/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 776, in binary_cross_entropy
    return _functions.thnn.BCELoss.apply(input, target, weight, size_average)
  File "/home/jhoward/anaconda3/lib/python3.6/site-packages/torch/nn/_functions/thnn/auto.py", line 47, in forward
    output, *ctx.additional_args)
RuntimeError: cudaEventSynchronize in future::wait: an illegal memory access was encountered

This is with the torchvision resnet34 model, with 244x244 input, and AdaptiveAveragePooling2D before the linear layer. Let me know if there’s any more info I can provide.

Pytorch is current version from conda as of yesterday. Python 3.6. All conda packages updated.


@jphoward, just so that I can reproduce this:

  • is this with imagenet? if not, how big is your dataset?
  • what GPU were you using?
  • what CUDA version did you install? did you use conda install pytorch torchvision cuda80 -c soumith ?

So sorry @smth, it was quite a while ago and I haven’t seen it again since, so I can’t even replicate it myself any more…

Got this error again after about 110K steps using nn.DataParallel. Following is the system config:

cuda: 9.1
pytorch: 0.3.0.post4 (conda install pytorch torchvision cuda90 -c pytorch)
Nvidia driver version: 387.26
GPU: 2x 1080ti

I was having this same issue, but on a tabular dataset using the fastai library on top of pytorch, and couldn’t figure it out. Ultimately I fixed it by running with CUDA_LAUNCH_BLOCKING=1 to get the real stack trace. That showed pytorch trying to save to a tmp directory that was empty but, for some reason, caused the illegal memory access error, possibly because the directory was being used by something else in the background. Anyway, deleting the empty directory fixed the issue.
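In case it helps anyone else: the misleading default traces come from asynchronous CUDA error reporting, so forcing synchronous launches is the first debugging step. A sketch of one way to do it from inside the script (the variable has to be set before the first CUDA call; setting it on the command line when launching the script works just as well):

import os
# Must be set before CUDA is initialized, i.e. before the first .cuda() call.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
# ... rest of the training script; the traceback now points at the call that actually failed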

CUDA error after cudaEventDestroy in future dtor: an illegal memory access was encountered

line 1483, in binary_cross_entropy
return torch._C._nn.binary_cross_entropy(input, target, weight, size_average, reduce)

RuntimeError: cudaEventSynchronize in future::wait: an illegal memory access was encountered

I encountered this error while training my model. I am using pytorch version 0.4.0, cuda version 9.0 and cudnn version 9.70. The only way currently for me to overcome this problem is to set

torch.backends.cudnn.enabled=False

If I disable cudnn, training takes more time than usual. Is there a fix for this issue?

Hi,
I hope it’s not too late. I also encountered the same issue while extracting C3D features for my work. I managed to solve it by adding this line:
net.volatile = True
Perhaps you could try something similar and make it work while keeping cuDNN enabled.

Solved a similar problem by setting cudnn.benchmark=False, but it loses some speed.

That’s not a proper solution.
Could you update to the latest stable PyTorch version and post a (minimal) executable code snippet so that we could take a look at it?