Getting "RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED" while running a basic RNN model

This is my code for a basic RNN model, which I am running on the MNIST dataset. My ultimate goal is to train this model on a custom dataset, but I am running it on MNIST first so that I can be sure the code and the model work properly before I try it on my own data.

When I run this model on my GPU, I get the error pasted below. Interestingly, when I run the model on the CPU instead of the GPU, it runs perfectly.
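Since the failure only appears on the GPU, one way to narrow it down is to run the same `nn.RNN` forward pass once with cuDNN disabled and once with it enabled. If the cuDNN run fails while the native run succeeds, that points at the cuDNN/GPU setup rather than the model code. This is a minimal diagnostic sketch, assuming the same sizes as the post (batch 64, 28x28 inputs, hidden size 256, 2 layers); the helper name `forward_once` is hypothetical.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def forward_once(use_cudnn: bool) -> torch.Tensor:
    """Hypothetical helper: one RNN forward pass with cuDNN switched on or off."""
    torch.manual_seed(0)  # identical weights and input for both runs
    # cuDNN is only used for CUDA tensors; on a CPU this flag has no effect.
    torch.backends.cudnn.enabled = use_cudnn
    rnn = nn.RNN(input_size=28, hidden_size=256, num_layers=2, batch_first=True).to(device)
    x = torch.randn(64, 28, 28, device=device)
    h0 = torch.zeros(2, 64, 256, device=device)
    out, _ = rnn(x, h0)
    return out.detach().cpu()


out_native = forward_once(use_cudnn=False)  # PyTorch's own RNN kernels
out_cudnn = forward_once(use_cudnn=True)    # cuDNN kernels when on a GPU
print("max abs difference:", (out_native - out_cudnn).abs().max().item())
```

If only the cuDNN path fails, setting `torch.backends.cudnn.enabled = False` at the top of the training script may work as a temporary (slower) workaround while the installation is sorted out.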

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from load_data import IntentEstimationDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_size = 28
sequence_length = 28
num_layers = 2
hidden_size = 256
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 2

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size*sequence_length, num_classes)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        out, _ = self.rnn(x, h0)
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)
        return out

# Load data:
train_dataset = datasets.MNIST(root='/home/sharyat/catkin_ws/src/data_imu/script/', train=True,
                               transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='/home/sharyat/catkin_ws/src/data_imu/script/', train=False,
                              transform=transforms.ToTensor(), download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,shuffle=True)

# initialise model:
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimiser:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epochs in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
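        # move the batch to the device; squeeze(1) drops the channel dim: (N, 1, 28, 28) -> (N, 28, 28)
        data = data.to(device=device).squeeze(1)
        targets = targets.to(device=device)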

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # optimiser step
        optimizer.step()

def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on training data')
    else:
        print('Checking accuracy on test data')

    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            # same preprocessing as in training: (N, 1, 28, 28) -> (N, 28, 28)
            x = x.to(device=device).squeeze(1)
            y = y.to(device=device)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum().item()
            num_samples += predictions.size(0)

        print(f'Got {num_correct}/{num_samples} with accuracy {num_correct / num_samples * 100:.2f}%')

    model.train()


check_accuracy(train_loader, model)
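Before digging into the GPU setup, it may help to confirm on the CPU that the tensor shapes line up end to end: `ToTensor()` yields MNIST batches of shape (N, 1, 28, 28), the RNN with `batch_first=True` expects (N, 28, 28), and the linear layer expects `hidden_size * sequence_length` features. The sketch below mirrors the RNN class from the post and runs one training step on a random MNIST-shaped batch; the random data and the small change of building `h0` on `x.device` (instead of a global `device`) are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

input_size, sequence_length, hidden_size = 28, 28, 256
num_layers, num_classes = 2, 10


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, sequence_length):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size * sequence_length, num_classes)

    def forward(self, x):
        # h0 is created on whatever device the input lives on
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x, h0)            # (N, 28, 256)
        out = out.reshape(out.shape[0], -1)  # (N, 28 * 256)
        return self.fc(out)                  # (N, 10)


torch.manual_seed(0)
model = RNN(input_size, hidden_size, num_layers, num_classes, sequence_length)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Random stand-in for one MNIST batch of 8 images
images = torch.rand(8, 1, 28, 28)
labels = torch.randint(0, num_classes, (8,))

scores = model(images.squeeze(1))
loss = criterion(scores, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(scores.shape, loss.item())
```

If this runs on the CPU but the real script still fails on the GPU, the model code is unlikely to be the culprit.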

After running the code for about 9 minutes, I get this error:

Traceback (most recent call last):
  File "/home/sharyat/catkin_ws/src/data_imu/script/", line 63, in <module>
    scores = model(data)
  File "/home/sharyat/.local/lib/python3.8/site-packages/torch/nn/modules/", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sharyat/catkin_ws/src/data_imu/script/", line 33, in forward
    out, _ = self.rnn(x, h0)
  File "/home/sharyat/.local/lib/python3.8/site-packages/torch/nn/modules/", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sharyat/.local/lib/python3.8/site-packages/torch/nn/modules/", line 227, in forward
    result = _impl(input, hx, self._flat_weights, self.bias, self.num_layers,

Process finished with exit code 1
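CUDA kernels are launched asynchronously, so the Python line shown in a CUDA traceback is not always where the failing kernel was actually launched. Setting the environment variable `CUDA_LAUNCH_BLOCKING=1` makes launches synchronous, so the error is raised at the call that caused it (at the cost of speed). It can be set from the shell (`CUDA_LAUNCH_BLOCKING=1 python your_script.py`, where the script name is a placeholder) or, as in this minimal sketch, from Python before CUDA is initialised; the tiny RNN here just stands in for the model from the post.

```python
import os

# Must be set before CUDA is initialised; before importing torch is the safest place.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same layer configuration as the post: input 28, hidden 256, 2 layers, batch_first
rnn = nn.RNN(28, 256, 2, batch_first=True).to(device)
x = torch.randn(64, 28, 28, device=device)

# With blocking launches, a CUDA error would be raised at this exact call.
out, h_n = rnn(x)  # h0 defaults to zeros when omitted
print(out.shape, h_n.shape)
```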

A few things I have considered after looking at other posts:

  1. My GPU memory usage is nowhere near the maximum.

Running nvidia-smi gives:

Mon Dec 27 17:56:35 2021       
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   39C    P5    13W /  N/A |    996MiB /  5946MiB |     25%      Default |
|                               |                      |                  N/A |
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|    0   N/A  N/A      1054      G   /usr/lib/xorg/Xorg                 45MiB |
|    0   N/A  N/A      1668      G   /usr/lib/xorg/Xorg                312MiB |
|    0   N/A  N/A      1859      G   /usr/bin/gnome-shell              153MiB |
|    0   N/A  N/A      2193      G   ...AAAAAAAAA= --shared-files       29MiB |
|    0   N/A  N/A      2589      G   /usr/lib/firefox/firefox          443MiB |
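`nvidia-smi` reports usage per process at the moment it is run. To see what PyTorch itself has allocated while training, one could log its own memory counters from inside the training loop. A minimal sketch, assuming a single GPU at index 0; the helper name `report_gpu_memory` is hypothetical.

```python
import torch


def report_gpu_memory(device_index: int = 0) -> dict:
    """Hypothetical helper: PyTorch's view of GPU memory in MiB (empty dict without CUDA)."""
    if not torch.cuda.is_available():
        return {}
    mib = 1024 ** 2
    props = torch.cuda.get_device_properties(device_index)
    return {
        "total": props.total_memory / mib,                                # physical memory on the card
        "allocated": torch.cuda.memory_allocated(device_index) / mib,     # held by live tensors
        "reserved": torch.cuda.memory_reserved(device_index) / mib,       # held by the caching allocator
        "peak_allocated": torch.cuda.max_memory_allocated(device_index) / mib,
    }


print(report_gpu_memory())
```

Calling it every few hundred batches would show whether memory grows over the 9 minutes before the crash.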

After running nvcc --version I get:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243

And after running print(torch.version) in Python, I get:

The CUDA version mismatch might be causing this, since PyTorch seems to be using 10.1 but your system seems to have 11.4 installed. Could you try installing CUDA toolkit 10.1 and see if that fixes the issue?

A newer driver should be able to run older CUDA runtimes, and the local CUDA toolkit won’t be used when the pip wheels or conda binaries are installed, as they ship with their own CUDA runtime.
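Because the binaries ship their own runtime, the versions worth reporting are the ones bundled with the installed PyTorch build rather than the system `nvcc`. A minimal sketch that prints them from Python:

```python
import torch

# Versions bundled with the installed PyTorch binary (independent of the system nvcc)
info = {
    "torch": torch.__version__,
    "cuda_runtime": torch.version.cuda,        # None for CPU-only builds
    "cudnn": torch.backends.cudnn.version(),   # None if cuDNN is unavailable
    "cuda_available": torch.cuda.is_available(),
}
if torch.cuda.is_available():
    info["gpu"] = torch.cuda.get_device_name(0)
    info["compute_capability"] = torch.cuda.get_device_capability(0)
print(info)
```

For a fuller report to paste into a thread like this one, `python -m torch.utils.collect_env` prints the PyTorch, CUDA, cuDNN, driver, and OS details in one go.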

@sharyat_singh could you update to the latest PyTorch release (1.10.1 or the nightly binary) and check if you are still hitting the issue?