RuntimeError: cuda runtime error (59) : device-side assert triggered

I am trying to train a bidirectional RNN on GPU but there is an error which always occur at the same execution time (approximately 62% of the first epoch).

Here is my model:

class RNNClassifier(nn.Module):

    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = int(bidirectional) + 1

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          bidirectional=bidirectional)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        input = input.t()
        batch_size = input.size(1)

        hidden = self._init_hidden(batch_size)

        embedded = self.embedding(input)

        self.gru.flatten_parameters()
        output, hidden = self.gru(embedded, hidden)

        fc_output = self.fc(hidden[-1])
        return fc_output

    def _init_hidden(self, batch_size):
        hidden = torch.zeros(self.n_layers * self.n_directions,
                             batch_size, self.hidden_size)
        return create_variable(hidden)

Here’s the code for my train() function:

def train():
    total_loss = 0
    tokens_done = 0
    for i, (sequence, labels) in enumerate(zip(train_tokens, train_labels), 1):
        input, target = make_batch(sequence, labels, window_size=WINDOW_SIZE)
        output = classifier(input)
        loss = criterion(output, target).cuda()
        total_loss += loss.item()
        classifier.zero_grad()
        loss.backward()
        optimizer.step()
     return total_loss

The complete error log:

THCudaCheck FAIL file=/pytorch/aten/src/THC/THCTensorCopy.cu line=102 error=59 : device-side assert triggered
Traceback (most recent call last):
  File "last_rnn.py", line 189, in <module>
    train()
  File "last_rnn.py", line 116, in train
    output = classifier(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 479, in __call__
    result = self.forward(*input, **kwargs)
  File "last_rnn.py", line 96, in forward
    self.gru.flatten_parameters()
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py", line 113, in flatten_parameters
    self.batch_first, bool(self.bidirectional))
RuntimeError: cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/THC/THCTensorCopy.cu:102
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [6,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [7,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [8,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [9,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [10,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [11,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [12,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [13,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [14,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [15,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [16,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [17,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [18,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [19,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [20,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [21,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [22,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [23,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [24,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [25,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [26,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [27,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [28,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [29,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

This seems to be a recurring problem with pytorch but none of the solutions I’ve read was of any help.

Hi,
This error happens when you try to index a Tensor with an index that is bigger than it’s size. Unfortunately, because of the way CUDA works, it is hard to raise easily the error without loosing speed.
You should set CUDA_LAUNCH_BLOCKING=1 and rerun your program, that way, the python stack trace will point to the correct line.
A common reason for such issues is that your labels are badly formated for a sample (that appears late in the training) and it makes the loss computation (or Embedding layer) try to index too far in your score Tensor (or enbedding matrix). You might want to double check that all the labels (and inputs) in your dataset are correct.

2 Likes

As @albanD suggested, issue is with accessing element out of bound.
Simple approach to debug is add

print(input.shape)

before calling classifier.
My guess is make_batch is increasing sequence length

I tried the CUDA_LAUNCH_BLOCKING=1 command and the program indeed returned a different error:

/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:308: void indexSelectSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 2, SrcDim = 2, IdxDim = -2]: block: [0,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "last_rnn.py", line 191, in <module>
    train()
  File "last_rnn.py", line 116, in train
    output = classifier(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 479, in __call__
    result = self.forward(*input, **kwargs)
  File "last_rnn.py", line 96, in forward
    self.gru.flatten_parameters()
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py", line 113, in flatten_parameters
    self.batch_first, bool(self.bidirectional))
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

As for the input shape, after checking it seems to be the correct size ([2, 5]). The others are of size [n, 5].

Hi! Have you solved the problem? i face the same problem, but i have not a idea

The previous issues were isolated by rerunning the code with CUDA_LAUNCH_BLOCKING=1 python script.pt args, which should print a stack trace with the line of code causing the issue.
As described in post, the failing operation was an out-of-bounds index, so the cudnn error was a red herring and was raised due to the asynchronous execution of CUDA operations.

Rerun the script with the mentioned env variable or rerun it on the CPU, which should also give you a clear error message.