My model performs poorly on MNIST, and with different data the loss and accuracy don't seem to improve during training at all

EDIT: The learning rate was too low, I guess; 0.1 works better.
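
For reference, that is the only line I changed (same SGD setup as in the code below):

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.0)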

I really need help with my code; I don't know what could be wrong, and I've run out of debugging ideas.
My model is definitely training, but accuracy sits at around 66% after 20 epochs, and that's with 2 convolutional layers and 3 fully connected layers. I don't know what's wrong! I originally used this model on a different kind of image data and adapted the code for the MNIST dataset, but it's still not working as intended. The loss seems far too high. I've played around with the learning rate and other optimizers, but nothing works better. With my original data the accuracy stayed the same after every iteration, probably because of a comparison between different data types; that issue doesn't seem to happen with MNIST, but the accuracy is still garbage. Is there a layer that's negatively affecting the output? Are the loss and accuracy being computed with the wrong data types, so that I need to convert something?
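
To illustrate the last question: as far as I know, nn.functional.cross_entropy applies log_softmax internally and expects raw logits, so I wonder whether the trailing Softmax in my model means softmax effectively gets applied twice. A minimal sketch with random data (not my actual model):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)                  # stand-in for raw network outputs
target = torch.randint(0, 10, (8,))

loss_from_logits = F.cross_entropy(logits, target)  # intended usage: raw logits
loss_from_probs = F.cross_entropy(torch.softmax(logits, dim=1), target)  # softmax twice

print(loss_from_logits.item(), loss_from_probs.item())
# the second value is squashed toward ln(10) ≈ 2.30, which also shrinks the gradients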

%%time

from os import path

import numpy as np
import pandas as pd
import time

import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler, Sampler
from torch import nn
import torchvision
from torchvision import datasets, transforms, models

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            #out = ((Input width - filter size + 2* padding) / Stride) + 1
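            # shape trace (assuming 28x28 MNIST input):
            # 28x28 -> Conv2d(3x3): 26x26 -> MaxPool2d(5,5): 5x5
            # -> Conv2d(3x3): 3x3, and 16 channels * 3 * 3 = 144 after Flatten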
            nn.Conv2d(1, 16, 3),
            nn.ReLU(),
            nn.MaxPool2d(5, 5),
            nn.Conv2d(16, 16, 3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(144, 120),
            nn.ReLU(),
            nn.Linear(120, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
            nn.Softmax(dim=1)  # explicit dim; nn.Softmax() without dim is deprecated
        )
        
    def forward(self, x):
        #x = x.permute(0, 3, 1, 2)
        x = self.net(x)
        return x
    
    def accuracy(self, out, yb):
        preds = torch.argmax(out, dim=1)
        return (preds == yb).float().mean()
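
# Quick shape sanity check (sketch, assuming 28x28 inputs): run everything up to
# Flatten on a dummy batch to confirm the 144 in-features of the first Linear layer.
print(CNN().net[:6](torch.zeros(1, 1, 28, 28)).shape)  # torch.Size([1, 144])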

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        pass
        #torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight.data)

params = {'batch_size': 64,
          'shuffle': True}
max_epochs = 200

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
        
model = CNN()
model.to(device)  # use the device chosen above instead of assuming CUDA
print(model)
model.apply(weights_init)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.0)


train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  **params)  # batch_size=64, shuffle=True, as defined in params above
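
# The Normalize constants (0.1307,), (0.3081,) are the MNIST training-set mean/std;
# a sketch of how they can be recomputed (assumes the same '/files/' root):
#   raw = torchvision.datasets.MNIST('/files/', train=True, download=True,
#                                    transform=torchvision.transforms.ToTensor())
#   imgs = torch.stack([img for img, _ in raw])
#   print(imgs.mean().item(), imgs.std().item())  # ~0.1307, ~0.3081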

start.record()
for epoch in range(max_epochs):
    # Training
    correct = 0
    total = 0
    print("start epoch. Correct:", correct, "total: ", total)
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        target = target.to(device)

        prediction = model(data.to(device))  # call the module, not .forward() directly
        argmax_prediction = torch.argmax(prediction, dim=1)

        loss = nn.functional.cross_entropy(prediction, target)

        total += target.size(0)
        correct += (argmax_prediction == target).sum().item()

        loss.backward()
        optimizer.step()

    print(f'Epoch: [{epoch+1}/{max_epochs}], Loss: {loss.item():.4f}, Acc: {correct} / {total}')

end.record()
# Waits for everything to finish running
torch.cuda.synchronize()

print(start.elapsed_time(end)/1000/60, "min")

The first few epochs look like this:
start epoch. Correct: 0 total: 0
Epoch: [1/200], Loss: 1.9858, Acc: 15113 / 60000
start epoch. Correct: 0 total: 0
Epoch: [2/200], Loss: 1.7330, Acc: 37511 / 60000
start epoch. Correct: 0 total: 0
Epoch: [3/200], Loss: 1.6763, Acc: 39267 / 60000
start epoch. Correct: 0 total: 0
Epoch: [4/200], Loss: 1.8991, Acc: 39999 / 60000
start epoch. Correct: 0 total: 0
Epoch: [5/200], Loss: 1.8558, Acc: 40489 / 60000
start epoch. Correct: 0 total: 0
Epoch: [6/200], Loss: 1.7643, Acc: 40777 / 60000
start epoch. Correct: 0 total: 0
Epoch: [7/200], Loss: 1.8476, Acc: 40991 / 60000
start epoch. Correct: 0 total: 0
Epoch: [8/200], Loss: 1.8249, Acc: 41139 / 60000
start epoch. Correct: 0 total: 0
Epoch: [9/200], Loss: 1.7904, Acc: 41278 / 60000
start epoch. Correct: 0 total: 0
Epoch: [10/200], Loss: 1.8400, Acc: 41384 / 60000
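
In case it matters, this is roughly how I plan to check accuracy on the held-out test split as well (a sketch that mirrors the training loop above):

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=64, shuffle=False)

model.eval()
correct, total = 0, 0
with torch.no_grad():  # no gradients needed during evaluation
    for data, target in test_loader:
        out = model(data.to(device))
        correct += (torch.argmax(out, dim=1) == target.to(device)).sum().item()
        total += target.size(0)
print(f'Test accuracy: {correct} / {total}')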