Classification - Possible to predict unknown class instead of wrong class?

Hi,
I’m working on a project where I would prefer the model to output some kind of “unknown” class when it is uncertain which class an input belongs to, and the correct class only when it is confident.

I think it could be done with a custom/modified loss function that gives no penalty if the predicted class is the correct one, a large penalty if the predicted class is wrong (which is basically CrossEntropyLoss so far), and a small penalty if the model predicts “unknown”.
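To make the idea concrete, here is a rough sketch of the kind of loss I mean. It is only my assumption of how it could look, not something I have trained with: the model gets one extra output for “unknown” (I put it at the last index), and the loss is the expected penalty under the softmax probabilities. Note that this is an expected-cost loss rather than cross-entropy, and the names `expected_cost_loss`, `wrong_cost` and `unknown_cost` are made up for illustration. It is written in plain Python for a single sample so the arithmetic is easy to follow; for training, the same arithmetic would have to be written with tensor operations so it can be differentiated.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_cost_loss(logits, target, wrong_cost=1.0, unknown_cost=0.2):
    """Expected penalty for one sample.

    Assumption: `logits` has K + 1 entries, where the last entry is the
    extra "unknown" output and `target` is the index of the true class
    (0 .. K - 1). The cost values are hypothetical constants to tune.
    """
    probs = softmax(logits)
    unknown_index = len(logits) - 1
    loss = 0.0
    for k, p in enumerate(probs):
        if k == target:
            cost = 0.0            # correct class: no penalty
        elif k == unknown_index:
            cost = unknown_cost   # abstaining: small penalty
        else:
            cost = wrong_cost     # wrong class: large penalty
        loss += p * cost
    return loss


# Two real classes plus "unknown"; the true class is index 0.
confident_right = expected_cost_loss([10.0, 0.0, 0.0], target=0)
confident_unknown = expected_cost_loss([0.0, 0.0, 10.0], target=0)
confident_wrong = expected_cost_loss([0.0, 10.0, 0.0], target=0)
```

With these constants, putting almost all probability on the correct class gives a loss close to 0, putting it on “unknown” gives a loss close to `unknown_cost`, and putting it on the wrong class gives a loss close to `wrong_cost`.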

There are some counterarguments that could be raised. With two classes, for example, the regular softmax output is a value between 0 (first class) and 1 (second class), so I could simply threshold it: treat every prediction below 0.1 as class 1, every prediction above 0.9 as class 2, and everything in between as “unknown”. But from what I have observed, the model almost always converges to returning values around 0-0.1 and 0.99-1 and almost never returns values in between. The difference between an accurate model and an inaccurate one is the percentage of correct predictions, but the values still sit near the two extremes, so this approach is not useful (unless I’m doing something wrong and there is a technique that forces the model to return values across the whole range between 0 and 1, corresponding to its confidence in the prediction).
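For reference, the thresholding approach described above would look roughly like this. It is a minimal sketch that assumes the model already outputs softmax probabilities as a list; `predict_with_reject` and `threshold` are names I made up, and 0.9 is just the cut-off from my example.

```python
def predict_with_reject(probs, threshold=0.9):
    """Return the index of the most probable class, or "unknown" when
    the highest probability is below `threshold`."""
    best = max(range(len(probs)), key=lambda k: probs[k])
    return best if probs[best] >= threshold else "unknown"


# Two-class example: probs = [P(first class), P(second class)]
print(predict_with_reject([0.95, 0.05]))  # confident in the first class
print(predict_with_reject([0.60, 0.40]))  # not confident enough
print(predict_with_reject([0.03, 0.97]))  # confident in the second class
```

This only helps if the probabilities actually reflect the model’s confidence, which is exactly what I don’t see in practice.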
Another counterpoint is that the new loss function might push the model to always predict the “unknown” class as the simplest way to reduce the loss, so that it fixates on this prediction without learning anything. I’m not sure whether this will happen or how to combat it if it does. Maybe I could tune the constants of the loss function, or start from a pre-trained model that can already predict some classes correctly.

My question is whether something like this has been done or attempted before. I didn’t find anything on Google, here, or in the literature, but that may be because I used the wrong terminology…

Additionally, if someone could help me with the actual code for the custom/modified loss function and give me some basic code to build upon, that would be great.

Thanks.