Classification linear model from scratch

Short intro: I don’t know a lot about the math behind machine learning, but I think I understand the basics and the general idea of it. At the moment I am following along with fastai’s course/book and decided that I’d like to write a multi-class linear classification model from scratch as a learning exercise.

I have a working version of this simple model which classifies 28*28 pictures of hand-written digits from the MNIST dataset. The highest accuracy I’ve been able to achieve is ~58%. Here’s a link to it on Kaggle: Digits recognizer | Kaggle

I would like to ask a few questions in order to understand what I could improve on:

  1. As I understand it, a model with random weights should have an accuracy of about 10%, so an accuracy of 58% should indicate that my model is doing something right and isn’t just blindly guessing, correct?

  2. If my model is indeed working, how good is an accuracy of 58%? What accuracy could one reasonably expect from a model with a single layer and no non-linearities?

  3. I have picked learning rates at random and chose the ones that gave good accuracy. What better way of picking learning rates would you suggest?

And finally, I would really appreciate it if someone were to look at my model and give some advice on how to improve it or my coding style. I am more than willing to provide some comments/explanations about my model in case my code is confusing.

Based on your code, it seems you are using a few manual implementations, which might suffer from numerical stability issues. E.g. replace softmax().log() with F.log_softmax for better numerical stability. Also, could you explain why argmin is used to compute the predictions instead of argmax?
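To make the stability point concrete, here is a minimal sketch comparing the two approaches on deliberately extreme logits. The input values are made up for illustration and are not taken from the linked notebook.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits with a very large spread, chosen to trigger underflow.
logits = torch.tensor([[1000.0, 0.0, -1000.0]])

# Two-step version: softmax first, then log.
# The small probabilities underflow to exactly 0, so log() yields -inf.
naive = logits.softmax(dim=1).log()

# Fused version: computes log-probabilities directly and stays finite.
stable = F.log_softmax(logits, dim=1)

print("softmax().log():", naive)
print("F.log_softmax  :", stable)
```

With -inf in the loss, gradients can turn into nan, which is why the fused function is usually preferred.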

Thank you for the suggestion.
I am aware of the built-in log_softmax function, but decided to write a manual implementation to see how well it would work. I had to introduce some correction terms such as b in softmax() to deal with exp(x) = inf overflow and epsilon in nll() for log(0) = -inf. Are there any other numerical stability issues I am not aware of?
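For reference, a minimal sketch of a manual log-softmax that uses the max-subtraction term b and needs no epsilon at all. The function name manual_log_softmax is hypothetical and this is not the code from the notebook; it assumes logits of shape (batch, classes).

```python
import torch

def manual_log_softmax(x: torch.Tensor) -> torch.Tensor:
    """Log-softmax over dim=1 using the log-sum-exp trick."""
    # b is the per-row maximum; subtracting it keeps every exponent <= 0,
    # so exp() cannot overflow to inf.
    b = x.max(dim=1, keepdim=True).values
    shifted = x - b
    # The sum contains exp(0) = 1 for the largest entry, so it is >= 1
    # and the log never sees 0. No epsilon is required.
    return shifted - shifted.exp().sum(dim=1, keepdim=True).log()

# Compare against the built-in on random and on extreme inputs.
torch.manual_seed(0)
random_logits = torch.randn(4, 10)
extreme_logits = torch.tensor([[1000.0, 0.0, -1000.0]])

print(torch.allclose(manual_log_softmax(random_logits),
                     torch.log_softmax(random_logits, dim=1), atol=1e-6))
print(manual_log_softmax(extreme_logits))
```

Computing the log-probabilities in one expression avoids the small bias that an added epsilon introduces into the loss.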

As for the usage of argmin(): the nll() function, the negative log likelihood, returns only non-negative values, where 0 corresponds to 100% confidence/probability and +inf to 0 confidence, so I use argmin() during the validation phase to pick the lowest value, which is the model’s “best guess”, and see if it was correct. That’s how I understood it based on chapter 5 in fastai’s book; please correct me if I’m wrong.
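A quick sketch of why the two choices agree, assuming nll() returns one negative log-probability per class (an assumption about the notebook, with random logits standing in for real model outputs): since -log(softmax(x)) is a decreasing function of x within each row, the argmin over the negative log-probabilities is the same index as the argmax over the raw logits.

```python
import torch

torch.manual_seed(0)

# Hypothetical model outputs: 5 samples, 10 classes.
logits = torch.randn(5, 10)

# Per-class negative log-probabilities, all >= 0.
neg_log_probs = -torch.log_softmax(logits, dim=1)

# Prediction via the lowest negative log-probability...
pred_argmin = neg_log_probs.argmin(dim=1)
# ...and via the highest raw logit.
pred_argmax = logits.argmax(dim=1)

print(torch.equal(pred_argmin, pred_argmax))
```

So at validation time the softmax and log can be skipped entirely and argmax applied to the logits directly.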