I wanted to build a simple ANN and train it from scratch on the Mnist dataset. The accuracy values look fine as expected but the loss is just way too high, as if I was not computing it correctly. What am I missing in the code below?

```
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.set_grad_enabled(True) #by default gradients are computed
from itertools import product
from sklearn.metrics import accuracy_score
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
# hidden layers
self.h1 = nn.Linear(in_features=784, out_features=196)
self.h2 = nn.Linear(in_features=196, out_features=49)
# output layer, 10 units - one for each digit
self.output = nn.Linear(49, 10)
def forward(self, x):
# (0) input layer
x = x
# (1) first hidden layer
x = self.h1(x)
x = F.relu(x)
# (2) second hidden layer
x = self.h2(x)
x = F.relu(x)
# (3) output layer
x = self.output(x)
# x = F.softmax(x, dim=1) -> already in cross-entropy loss function??
return x
def get_num_correct(preds, labels):
return preds.argmax(dim=1).eq(labels).sum().item()
train_set = torchvision.datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
parameters = dict(lr = [.01], batch_size = [32, 64])
param_values = [v for v in parameters.values()]
for lr, batch_size in product(*param_values):
comment = f' batch_size={batch_size} lr={lr}'
print(comment)
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
optimizer = optim.Adam(network.parameters(), lr=lr)
images, labels = next(iter(train_loader))
for epoch in range(3):
total_loss = 0
total_correct = 0
all_preds = torch.Tensor([])
for batch in train_loader:
images, labels = batch # Get Batch
preds = network(images.unsqueeze(0).reshape(-1, 28*28*1)) # Pass batch
all_preds = torch.cat((all_preds, preds), dim=0)
loss = F.cross_entropy(preds, labels) # Calculate Loss
optimizer.zero_grad() # Zero Gradients
loss.backward() # Calculate Gradients
optimizer.step() # Update Weights
total_loss += loss.item() * batch_size # is this correct? why?
total_correct += get_num_correct(preds, labels)
print(f"""epoch: {epoch}; total loss: {total_loss};
accuracy: {accuracy_score(train_loader.dataset.targets, all_preds.argmax(dim=1))}""")
```