MNIST Classifier in PyTorch - using a prior to exclude a class

Hello everyone

I’ve written PyTorch code that uses ResNet to predict digits 0-9 from the standard MNIST database. It works pretty well, but I want to improve it. In particular, in my application I usually have prior information telling me that a particular image is not a particular digit. For example, I know that the image being classified is not the number “6”.

Is there any known approach that uses this prior information to restrict the classification, so that the classifier gives a 0% probability of the image being 6? I know that I can simply zero out the classifier output for label 6 and renormalize the other classes (i.e. 0-9 excluding 6). But this is just a linear rescaling of the final layer's output.
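For reference, the "zero out and renormalize" option can also be written as masking the logits before the softmax: setting the excluded class's logit to -inf gives exactly the same probabilities. If the network's output is a good estimate of p(y | x), this renormalization is also the exact conditional p(y | x, y ≠ 6) = p(y | x) / (1 - p(6 | x)) for y ≠ 6. The framework-free sketch below checks the equivalence numerically; the logit values are made up for illustration and the helper names are hypothetical.

```python
import math

def softmax(logits):
    # Subtract the largest finite logit for numerical stability; exp(-inf) == 0.0
    m = max(z for z in logits if z != float("-inf"))
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def renormalize_excluding(probs, excluded):
    """Zero out the excluded classes and rescale the rest so they sum to 1."""
    kept = [0.0 if k in excluded else p for k, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]

def masked_softmax(logits, excluded):
    """Set the excluded logits to -inf, then apply the softmax."""
    masked = [float("-inf") if k in excluded else z for k, z in enumerate(logits)]
    return softmax(masked)

# Made-up scores for digits 0-9 (digit 6 has the highest score)
logits = [1.2, -0.3, 0.5, 2.0, -1.0, 0.1, 2.5, -0.7, 0.9, 0.0]
renormalized = renormalize_excluding(softmax(logits), {6})
masked = masked_softmax(logits, {6})
```

In PyTorch the masked version is typically written with `logits.masked_fill(mask, float("-inf"))` before `softmax`, where `mask` is a boolean tensor marking the excluded classes.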

I’d like to use this prior information at the input layer (?) or somewhere inside the network, to take advantage of all of the network's nonlinear transformations. I’m OK with writing a toy CNN from scratch (i.e. bypassing ResNet to begin with).
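One way to let the hidden layers see the prior is to condition the network on it: pass a 10-element "allowed classes" vector (1 = still possible, 0 = ruled out) as an extra input, and also mask the logits so ruled-out classes get exactly 0 probability. The toy CNN below is a hypothetical sketch of that idea, not an established method; `PriorConditionedCNN`, `random_allowed` and `p_exclude` are illustrative names. It assumes the network is trained with randomly generated exclusion vectors that never rule out the true label, so it learns how to use that input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PriorConditionedCNN(nn.Module):
    """Hypothetical toy CNN for 28x28 MNIST conditioned on an allowed-classes vector."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # (B,1,28,28) -> (B,16,28,28)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # (B,16,14,14) -> (B,32,14,14)
        # Fully connected layer sees image features AND the prior vector
        self.fc1 = nn.Linear(32 * 7 * 7 + num_classes, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
        # x: (B, 1, 28, 28); allowed: (B, num_classes) float, 1.0 = class still possible
        h = F.max_pool2d(F.relu(self.conv1(x)), 2)      # -> (B, 16, 14, 14)
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)      # -> (B, 32, 7, 7)
        h = torch.cat([h.flatten(1), allowed], dim=1)   # inject the prior
        h = F.relu(self.fc1(h))
        logits = self.fc2(h)
        # Hard constraint: ruled-out classes get -inf, i.e. probability exactly 0
        return logits.masked_fill(allowed == 0, float("-inf"))

def random_allowed(labels: torch.Tensor, num_classes: int = 10,
                   p_exclude: float = 0.3) -> torch.Tensor:
    """Random training-time prior: each class is ruled out with probability
    p_exclude, but the true label is never ruled out."""
    allowed = (torch.rand(labels.size(0), num_classes) > p_exclude).float()
    allowed[torch.arange(labels.size(0)), labels] = 1.0
    return allowed
```

During training, each batch would call `model(images, random_allowed(labels))` with the usual `F.cross_entropy` loss; at test time you pass the real prior instead. Concatenating at the first fully connected layer is just one choice: the vector could instead be tiled into 10 constant 28x28 planes and stacked with the image as extra input channels (`in_channels=11` for `conv1`). Whether either variant beats plain renormalization is an empirical question, so comparing both on a held-out set with the same priors would be the way to find out.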

Any advice is much appreciated. I know this is a bit off-topic since it isn't purely a PyTorch question, but I’m hoping someone has some general advice.