Sequential MNIST task: RNN not learning

I am trying to validate my RNN, written in plain PyTorch, on the sequential MNIST task: each regular 28 by 28 MNIST image is flattened into a single vector of 28*28 = 784 pixels, and the pixels are fed to the model one at a time to perform MNIST classification. The idea is to examine how well the RNN can memorise the context, i.e. the history of the single pixels of an entire MNIST image, fed in sequentially.
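
For concreteness, a minimal sketch of that reshaping (the tensor below is a random stand-in for a real sample):

import torch

image = torch.rand(1, 28, 28)  # stand-in for one MNIST image after ToTensor()
sequence = image.view(-1, 1)   # [784, 1]: 784 timesteps, one pixel per step
print(sequence.shape)          # torch.Size([784, 1])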

I have checked various things, such as the loss function, whether predictions are being made, and how pixels are fed into the model, but I was not able to work out why the model is not learning. I would be happy for any suggestions, ideas, or solutions the community has to offer.
If you have worked on sequential MNIST before, even better. Please let me know. You can find my code below; it can be copy-pasted into a regular Google Colab notebook and will print training metrics (loss, top-1 accuracy, top-5 accuracy), none of which improve meaningfully over multiple epochs:

Prints

Epoch:1   Train[Loss:2.3024 Top1 Acc:0.0986  Top5 Acc:0.5045]
Epoch:2   Train[Loss:2.3022 Top1 Acc:0.112  Top5 Acc:0.5045]
Epoch:3   Train[Loss:2.3021 Top1 Acc:0.1124  Top5 Acc:0.5122]
Epoch:4   Train[Loss:2.3019 Top1 Acc:0.1124  Top5 Acc:0.5139]
Epoch:5   Train[Loss:2.3018 Top1 Acc:0.1124  Top5 Acc:0.5163]
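
(For reference, a uniform guess over 10 classes gives a cross-entropy of -ln(1/10) ≈ 2.3026, which is exactly where the loss is pinned, and a top-1 accuracy of 0.1124 matches always predicting the most frequent training digit, "1", which makes up 6742/60000 ≈ 0.1124 of the training set.)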

Notebook code

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets
from torchvision import transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader


class RnnCell(nn.Module):
    def __init__(self, input_size, hidden_size, activation="tanh"):
        super(RnnCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation
        if self.activation not in ["tanh", "relu", "sigmoid"]:
            raise ValueError("Invalid nonlinearity selected for RNN. Please use tanh, relu, or sigmoid.")

        self.input2hidden = nn.Linear(input_size, hidden_size)
        self.hidden2hidden = nn.Linear(hidden_size, hidden_size)
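        # Together these implement the update h_t = act(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)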

    def forward(self, input, carry, hidden_state=None):
        """
        Inputs: input (torch tensor) of shape [batch_size, input_size]
                carry (tuple of torch tensors), each of shape [batch_size, hidden_size]
                hidden_state (torch tensor or None) of shape [batch_size, hidden_size]
        Output: tuple (out, out) of torch tensors, each of shape [batch_size, hidden_size]
        """

        # If no explicit hidden state is given, create a zero state and
        # overwrite the incoming carry with it
        if hidden_state is None:
            hidden_state = torch.zeros(input.shape[0], self.hidden_size, device=input.device)
            carry = (hidden_state, hidden_state)

        # Linear recurrence: combine the current input with the previous hidden state
        h_t, _ = carry
        h_t = self.input2hidden(input) + self.hidden2hidden(h_t)

        # Takes output from hidden and applies activation
        if self.activation == "tanh":
            out = torch.tanh(h_t)
        elif self.activation == "relu":
            out = torch.relu(h_t)
        elif self.activation == "sigmoid":
            out = torch.sigmoid(h_t)
        return (out, out)
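
# Shape sketch for a single cell step (illustrative, not part of the training
# run): with batch size 8 and hidden_size=64,
#   cell = RnnCell(input_size=1, hidden_size=64)
#   h0 = torch.zeros(8, 64)
#   new_h, out = cell(torch.zeros(8, 1), (h0, h0))  # both have shape [8, 64]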


class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, activation="relu"):
        super(SimpleRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.activation = activation

        if self.activation not in ["tanh", "relu", "sigmoid"]:
            raise ValueError("Invalid activation. Please use tanh, relu, or sigmoid activation.")

        self.rnn_cell = RnnCell(self.input_size, self.hidden_size, self.activation)
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden_state=None):
        """
        Inputs: input (torch tensor) of shape [batch_size, seq_len, input_size]
        Output: output (torch tensor) of shape [batch_size, output_size]
        """

        # Initialize the hidden state at the first timestep if none is given
        if hidden_state is None:
            hidden_state = torch.zeros(input.shape[0], self.hidden_size, device=input.device)

        outs = []

        for t in range(input.size(1)):
            hidden_state, out = self.rnn_cell(input[:, t, :], (hidden_state, hidden_state))
            # collect output of rnn_cell to outs list
            outs.append(out)

        # Select last time step indexed at [-1]
        out = outs[-1].squeeze()

        out = self.fc(out)
        return out
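
# End-to-end shape sketch (illustrative): for x of shape [64, 784, 1] the loop
# above runs 784 single-pixel steps, outs[-1] has shape [64, 64], and self.fc
# maps it to [64, 10] class logits.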

def train(data_loader, model, optimizer, loss_f):
    loss_lst = []
    top1_correct = 0
    top5_correct = 0
    total = 0
    model.train()

    for batch_idx, (x, y) in enumerate(data_loader):
        x, y = x.to(device), y.to(device)
        # Flatten each 28x28 image into a sequence of 784 single-pixel inputs
        x = x.view(x.size(0), -1, 1)
        out = model(x)

        loss_val = loss_f(out, y)
        loss_lst.append(loss_val.item())

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Top-1: argmax prediction; top-5: true label among the 5 highest logits
        _, predicted = torch.max(out.detach(), 1)
        top5 = out.detach().topk(5, dim=1).indices
        total += y.size(0)
        top1_correct += (predicted == y).sum().item()
        top5_correct += (top5 == y.unsqueeze(1)).any(dim=1).sum().item()

    # Average loss and accuracies over the epoch
    avg_loss = round(sum(loss_lst) / len(loss_lst), 4)
    top1_acc = round(top1_correct / total, 4)
    top5_acc = round(top5_correct / total, 4)
    return avg_loss, top1_acc, top5_acc
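
# Debugging sketch (not part of the training run): after loss_val.backward(),
# one can verify that gradients actually reach the recurrent weights, e.g.
#   for name, p in model.named_parameters():
#       print(name, None if p.grad is None else p.grad.norm().item())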


device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 64
weight_decay = 0.0005
epochs = 20
nworkers = 2
lr = 0.0001
pin_memory = True
data_dir = 'data/'

train_dataset = datasets.MNIST(root=data_dir,
                               train=True,
                               transform=T.Compose([T.ToTensor()]),
                               download=True)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=nworkers,
                          pin_memory=pin_memory)

model = SimpleRNN(input_size=1, hidden_size=64, output_size=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
loss_f = nn.CrossEntropyLoss()
for epoch in range(epochs):
    train_loss_value, train_top1, train_top5 = train(train_loader, model, optimizer, loss_f)

    print(f"Epoch:{epoch + 1}   Train[Loss:{train_loss_value}  Top1 Acc:{train_top1}  Top5 Acc:{train_top5}]")