Validation accuracy seems too high (85%) after 1 epoch of training on Fashion-MNIST

Is this expected? I’m using a very basic fully connected network (only `nn.Linear` layers), and 85% feels too high after only one epoch (a single pass over the training data).

Getting the data

# Both pipelines apply identical preprocessing: ToTensor() produces a float tensor
# in [0, 1], and Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, mapping it to [-1, 1].
train_transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
test_transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# NOTE: this is Fashion-MNIST (clothing images), not the handwritten-digit MNIST.
training_set = torchvision.datasets.FashionMNIST(
    '/kaggle/working/mnist', train=True, transform=train_transformer, download=True
)

# 80/20 train/validation split. The validation size is taken as the remainder so the
# two lengths always sum to len(training_set); int(0.8*n) + int(0.2*n) can come up
# one short for some n, and random_split raises when the lengths don't add up.
num_total = len(training_set)
num_train = int(0.8 * num_total)
training_set, validation_set = torch.utils.data.random_split(
    training_set, [num_train, num_total - num_train]
)

trainloader = torch.utils.data.DataLoader(
    training_set, batch_size=64, shuffle=True, pin_memory=True, num_workers=3
)
validationloader = torch.utils.data.DataLoader(
    validation_set, batch_size=64, shuffle=False, pin_memory=True, num_workers=3
)



Network

class Net(nn.Module):
    """Fully connected classifier: num_features_in -> 256 -> 128 -> 64 -> num_features_out.

    Hidden layers use ReLU. ``forward`` returns log-probabilities
    (``log_softmax`` over dim 1), so it is meant to be paired with ``nn.NLLLoss``.
    """

    def __init__(self, num_features_in, num_features_out):
        super().__init__()
        self.fc1 = nn.Linear(num_features_in, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, num_features_out)

    def forward(self, inp):
        """Return (batch, num_features_out) log-probabilities for a batch ``inp``.

        Every dimension after the first (batch) dimension is flattened, so the
        product of those dimensions must equal ``num_features_in``.
        """
        # flatten(1) (unlike view) also handles non-contiguous and empty batches.
        x = inp.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.log_softmax(self.fc4(x), dim=1)

Training and Validation methods

def get_accuracy(predictions, labels):
    """Return how many rows of ``predictions`` have their argmax equal to ``labels``.

    Args:
        predictions: (batch, classes) scores; in this script, the log-probabilities
            returned by ``Net``.
        labels: (batch,) integer class indices.

    Returns:
        The number of correct predictions as a Python int (a count, not a
        fraction; callers divide by the number of examples).
    """
    # exp() is monotonic, so it never changes the argmax. Taking the argmax directly
    # on the log-probabilities skips an extra tensor op. It also avoids exp()
    # underflowing very negative scores to tied zeros, which can select the wrong class.
    predicted = predictions.argmax(dim=1)
    return (predicted == labels).sum().item()
def train(loader, losses, accuracies):
    """Run one training epoch of the module-level ``net`` over ``loader``.

    Uses the module-level ``net``, ``optimizer``, ``criterion`` and ``device``.
    Appends the epoch's mean per-example loss to ``losses`` and its accuracy
    (fraction of correctly classified examples) to ``accuracies``, and prints both.

    Note: the accuracy is accumulated while the weights are still being updated,
    so it averages over a model that improves during the epoch. It is therefore
    usually lower than an accuracy measured after the epoch has finished.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    # Don't rely on some other function having restored training mode.
    net.train()

    running_loss = 0.0
    correct = 0
    num_images = 0

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        predictions = net(images)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.shape[0]
        # Assumes criterion uses reduction='mean' (the nn.NLLLoss() default used in
        # Main). Weight by batch size so a smaller final batch doesn't skew the
        # per-example average.
        running_loss += loss.item() * batch_size
        num_images += batch_size
        correct += get_accuracy(predictions, labels)

    if num_images == 0:
        raise ValueError("train(): loader yielded no examples")

    avg_loss = running_loss / num_images
    losses.append(avg_loss)
    print(f"Training loss per example/image: {avg_loss}")
    avg_accuracy = correct / num_images
    accuracies.append(avg_accuracy)
    print(f"Training Accuracy per epoch: {avg_accuracy}")
def validate(testing, loader, losses, accuracies):
    """Evaluate the module-level ``net`` on ``loader`` without updating weights.

    Args:
        testing: truthy -> outputs are labelled "Testing"; falsy -> "Validation".
        loader: iterable of (images, labels) batches.
        losses: list; the mean per-example loss is appended to it.
        accuracies: list; the fraction of correct predictions is appended to it.

    Uses the module-level ``net``, ``criterion`` and ``device``. The model is put
    in eval mode for the pass and is switched back to training mode afterwards,
    even if an exception is raised.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    type_ = 'Testing' if testing else 'Validation'

    net.eval()
    try:
        running_loss = 0.0
        correct = 0
        num_images = 0

        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                # Forward pass only; no_grad() skips building the autograd graph.
                predictions = net(images)
                loss = criterion(predictions, labels)

                batch_size = labels.shape[0]
                # Assumes criterion uses reduction='mean'; weight by batch size so
                # the result is a true per-example average.
                running_loss += loss.item() * batch_size
                num_images += batch_size
                correct += get_accuracy(predictions, labels)

        if num_images == 0:
            raise ValueError("validate(): loader yielded no examples")

        avg_loss = running_loss / num_images
        losses.append(avg_loss)
        print(f"{type_} loss per example/image: {avg_loss}")
        avg_accuracy = correct / num_images
        accuracies.append(avg_accuracy)
        print(f"{type_} Accuracy per epoch: {avg_accuracy}")
    finally:
        # Put the model back in training mode.
        net.train()

Main

################
##    Main    ##
################

# Model with 784 input features and 10 output classes. Net emits log-probabilities,
# which is what NLLLoss expects.
net = Net(784, 10).to(device)
criterion = nn.NLLLoss()
optimizer = optim.Adam(net.parameters())

# Training
epochs = 35
train_losses, train_accuracies = [], []
validation_losses, validation_accuracies = [], []

for e in range(epochs):
    # One pass over the training split, then an intermediate checkpoint.
    train(trainloader, train_losses, train_accuracies)
    save_checkpoint(net, optimizer, 0, e)

    # Score the held-out split with the weights as they stand after this epoch.
    validate(testing=0, loader=validationloader,
             losses=validation_losses, accuracies=validation_accuracies)

    print("epoch is: ", e)

# Final checkpoint, tagged with the last epoch index.
save_checkpoint(net, optimizer, 1, e)

Output

Training loss per example/image: 0.543968736787637
Training Accuracy per epoch: 0.8029791666666667
Validation loss per example/image: 0.42178823727559533
Validation Accuracy per epoch: 0.8445
epoch is:  0
Training loss per example/image: 0.3901872558196386
Training Accuracy per epoch: 0.857125
Validation loss per example/image: 0.40278140400001344
Validation Accuracy per epoch: 0.8504166666666667
epoch is:  1
Training loss per example/image: 0.3478324411014716
Training Accuracy per epoch: 0.8726875
Validation loss per example/image: 0.38001503831053035
Validation Accuracy per epoch: 0.85375
epoch is:  2

Hi, @stroncea!
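I’m not an expert either, but I think MNIST-style datasets are fairly easy. (Note that your code actually loads Fashion-MNIST, not the handwritten-digit MNIST.) I’m not “taken aback” that a fully connected network would perform well, so maybe it’s ok?

Keep in mind that one epoch here is already 750 optimizer steps (48,000 training images / batch size 64), so the network has received many updates before it is first validated. This also explains why validation accuracy can exceed the reported training accuracy. The training number is averaged over the whole epoch, including early batches scored by a barely-trained model, while validation is measured once, at the end of the epoch.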

With that said, even though your network is “basic”, it actually has a lot of trainable parameters: about 243k (784·256 + 256·128 + 128·64 + 64·10 weights, plus biases). It will likely be slower to train, have a larger-than-necessary memory footprint, and be more prone to overfitting than a network that includes convolutions.

–SEH