Hello,
I have an image classification task, but I am using an object detection model for it. The model outputs a detection for every object it locates in the image and assigns a confidence to each one. However, I only want to assign a single class to the entire image, so I need some postprocessing.
First, I have a vector of K threshold values, one per class. I threshold all confidences: if a confidence is below the threshold of its class, I clamp it to zero (essentially deleting the detection).
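A minimal NumPy sketch of this thresholding step, assuming each image's detections are stored as two parallel arrays (`classes` holding class indices 0..K-1 and `confs` holding raw confidences). This layout and the name `apply_thresholds` are hypothetical, not taken from any particular detection library:

```python
import numpy as np

def apply_thresholds(classes, confs, thresholds):
    """Clamp to zero every confidence that is below its own class's threshold.

    classes:    int array, shape (D,), class index (0..K-1) of each detection
    confs:      float array, shape (D,), raw confidence of each detection
    thresholds: float array, shape (K,), one threshold per class
    """
    # thresholds[classes] looks up the threshold that applies to each detection.
    return np.where(confs < thresholds[classes], 0.0, confs)


# Thresholds 0.5 (class 0) and 0.3 (class 1): only the 0.35 class-1 detection survives.
print(apply_thresholds(np.array([0, 1, 1]), np.array([0.4, 0.35, 0.2]), np.array([0.5, 0.3])))
```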
I also have a vector of K class weights, one per class, and I multiply every confidence by the weight of its class. This is used to “re-order” the confidences so as to improve the accuracy. For example, the model might detect a certain class only with low confidence, even though that class is a high priority for the user (whenever it appears, the user always wants the image to be classified as that class).
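Under the same hypothetical array layout, the weighting and the image-level decision could look like the sketch below. It assumes all weights are positive and uses `-1` as a stand-in label for the “None” class when no detection survives thresholding; `classify_image` and `NONE` are illustrative names.

```python
import numpy as np

NONE = -1  # hypothetical stand-in label for the "None" (no detections) class

def classify_image(classes, thresholded_confs, weights):
    """Image-level class = class of the detection with the highest weighted score.

    classes:           int array, shape (D,), class index of each detection
    thresholded_confs: float array, shape (D,), confidences after thresholding
    weights:           float array, shape (K,), one positive weight per class
    """
    scores = thresholded_confs * weights[classes]
    # Deleted (zeroed) detections cannot win; if nothing is left, predict None.
    if scores.size == 0 or scores.max() <= 0.0:
        return NONE
    return int(classes[np.argmax(scores)])


# Class 1 has the lower confidence (0.5 vs 0.8) but wins after a weight of 2.
print(classify_image(np.array([0, 1]), np.array([0.8, 0.5]), np.array([1.0, 2.0])))  # prints 1
```

Because the weights are applied after thresholding, they can only re-rank detections that survived; a detection clamped to zero stays at zero whatever its weight.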
The mathematical problem is to find the values of both the thresholds and the weights that give the highest accuracy, i.e. such that, in as many images as possible, the highest-confidence detection after thresholding and weighting coincides with the ground truth.
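Written out, the objective might look like this. The notation is introduced here for illustration: image i (out of N) has detections j with class k_{ij} and raw confidence c_{ij}, ground truth y_i, thresholds t, positive weights w, and ∅ stands for the “None” class. A detection survives when c_{ij} ≥ t_{k_{ij}}, matching the clamp-below-threshold rule above.

```latex
s_{ij}(\mathbf{t},\mathbf{w}) = w_{k_{ij}}\, c_{ij}\, \mathbb{1}\!\left[c_{ij} \ge t_{k_{ij}}\right]

\hat{y}_i(\mathbf{t},\mathbf{w}) =
\begin{cases}
  k_{i j^\ast}, \quad j^\ast = \arg\max_j s_{ij}, & \text{if } \max_j s_{ij} > 0,\\
  \varnothing, & \text{otherwise}
\end{cases}

\mathrm{Acc}(\mathbf{t},\mathbf{w}) = \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}\!\left[\hat{y}_i(\mathbf{t},\mathbf{w}) = y_i\right],
\qquad
(\mathbf{t}^\ast,\mathbf{w}^\ast) = \arg\max_{\mathbf{t},\,\mathbf{w}} \mathrm{Acc}(\mathbf{t},\mathbf{w})
```

Since Acc only changes when an indicator or an argmax flips, it is piecewise constant in both t and w, so its gradient is zero almost everywhere; this is why a surrogate loss is needed even for the weights.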
I know that this scheme carries a significant risk of overfitting, that accuracy is not the best metric when the dataset is not well-balanced, and so on, but I am interested in this purely as a mathematical question.
For the weights, I currently run gradient descent on a binary cross-entropy loss. I use binary cross-entropy because, as far as I know, regular (categorical) cross-entropy cannot handle the “None” (no detections) class; please correct me if I’m wrong. This gives decent results.
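The exact binary cross-entropy setup is not spelled out here, so the following is only one plausible reading, written as a sketch: for each image and class, take the highest thresholded confidence of that class (0 if absent), multiply it by the class weight, clip it into (0, 1), and compare it with a one-hot target in which a “None” image has an all-zero row. The gradient is derived by hand (dL/dp = (p - y) / (p(1 - p)), times the confidence for the chain rule); the function names and the hyperparameters `lr`, `steps` and `eps` are assumptions.

```python
import numpy as np

def one_hot_targets(labels, K):
    """One-hot target rows; a None image (label -1) gets an all-zero row."""
    Y = np.zeros((len(labels), K))
    has_label = labels >= 0
    Y[np.flatnonzero(has_label), labels[has_label]] = 1.0
    return Y

def bce_loss(X, labels, w, eps=1e-6):
    """BCE summed over classes, averaged over images, on clipped weighted confidences."""
    Y = one_hot_targets(labels, X.shape[1])
    p = np.clip(X * w, eps, 1.0 - eps)
    return float(-(Y * np.log(p) + (1.0 - Y) * np.log(1.0 - p)).sum(axis=1).mean())

def fit_weights_bce(X, labels, lr=0.5, steps=500, eps=1e-6):
    """Plain gradient descent on the per-class weights w (shape (K,)).

    X:      shape (N, K); X[i, k] = highest thresholded confidence of class k
            in image i, or 0 if class k has no surviving detection
    labels: int array, shape (N,); ground-truth class index, or -1 for "None"
    """
    Y = one_hot_targets(labels, X.shape[1])
    w = np.ones(X.shape[1])
    for _ in range(steps):
        raw = X * w                                # weighted confidences
        p = np.clip(raw, eps, 1.0 - eps)           # read them as probabilities
        inside = (raw > eps) & (raw < 1.0 - eps)   # the clip has zero gradient outside
        dL_dp = (p - Y) / (p * (1.0 - p))          # derivative of BCE w.r.t. p
        grad = (dL_dp * inside * X).mean(axis=0)   # chain rule: d(raw)/d(w_k) = X[:, k]
        w = np.maximum(w - lr * grad, 0.0)         # keep the weights non-negative
    return w


X = np.array([[0.9, 0.0], [0.6, 0.0], [0.0, 0.3], [0.2, 0.0]])
labels = np.array([0, 0, 1, -1])
w = fit_weights_bce(X, labels)
print(w, bce_loss(X, labels, w), bce_loss(X, labels, np.ones(2)))
```

Under this reading the loss decomposes over classes, so each w_k is fitted independently of the others, and a weight stops moving once all of its non-zero weighted confidences are clipped at 1 - eps.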
However, I am not sure how to optimize the thresholds, since the thresholding is non-differentiable, short of using brute force or a heuristic such as a genetic algorithm or grid search. I have tried an iterative local search, but it is slow to run and slow to improve.
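One way to make such a local search exact along each coordinate, sketched under the same hypothetical data layout: with the weights and all other thresholds fixed, accuracy can only change when t_k crosses the confidence of some class-k detection, so trying 0, every distinct class-k confidence, and +inf (which deletes all class-k detections) covers every distinct outcome for that class. This is plain greedy coordinate ascent, so it reaches a coordinate-wise optimum rather than a guaranteed global one; `predict`, `accuracy` and `coordinate_threshold_search` are illustrative names.

```python
import numpy as np

NONE = -1  # hypothetical label for "no detection survived"

def predict(classes, confs, thresholds, weights):
    """Image-level prediction: threshold, weight, then take the top detection."""
    scores = np.where(confs < thresholds[classes], 0.0, confs) * weights[classes]
    if scores.size == 0 or scores.max() <= 0.0:
        return NONE
    return int(classes[np.argmax(scores)])

def accuracy(images, labels, thresholds, weights):
    """Fraction of images whose prediction equals the label (labels: int array)."""
    preds = [predict(c, s, thresholds, weights) for c, s in images]
    return float(np.mean(np.array(preds) == labels))

def coordinate_threshold_search(images, labels, weights, max_sweeps=10):
    """Exact line search over one class threshold at a time, repeated until stable."""
    K = len(weights)
    thresholds = np.zeros(K)
    best = accuracy(images, labels, thresholds, weights)
    for _ in range(max_sweeps):
        improved = False
        for k in range(K):
            # Only these values of t_k can produce distinct sets of surviving detections.
            confs_k = np.concatenate([s[c == k] for c, s in images])
            for t in np.concatenate([[0.0], np.unique(confs_k), [np.inf]]):
                trial = thresholds.copy()
                trial[k] = t
                acc = accuracy(images, labels, trial, weights)
                if acc > best:  # strict improvement, so the loop cannot cycle
                    best, thresholds, improved = acc, trial, True
        if not improved:
            break
    return thresholds, best


# Each image: (class indices, raw confidences).
images = [
    (np.array([0, 1]), np.array([0.6, 0.5])),
    (np.array([0]), np.array([0.9])),
    (np.array([1]), np.array([0.2])),
    (np.array([1]), np.array([0.7])),
]
labels = np.array([1, 0, NONE, 1])
print(coordinate_threshold_search(images, labels, np.ones(2)))
```

On this toy data the search raises accuracy from 0.5 to 1.0 with thresholds of 0.9 and 0.5. Each candidate costs one full pass over the dataset; alternating this search with the gradient step for the weights is one possible way to tune both.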
I’ve been struggling with this problem for a while, and any help would be greatly appreciated.
Thanks!