Using priors for inference?

I have managed to train a simple neural net to predict k=3 labels (the original dataset is imbalanced, so I downsampled the dominant class to make it roughly balanced and used class weights in the loss function during training).

I now want to use this trained network for inference on a new dataset for which I know the true (imbalanced) label distribution.

How can I use this knowledge of the label distribution when running inference on the new dataset?

Perhaps KLDivLoss would allow this, but I’m not sure.

Hey, thanks. Since I’m doing inference, I’m not sure any loss function would be involved, unless I’m missing something.

Interesting problem :slight_smile:

In this case, there are two ideas that come to mind:

  1. Don’t balance your dataset classes to be uniform. Instead, balance them to represent the real distribution of labels. Typically, we balance classes because we’re working under the assumption that we want our model to be equally good at (or equally unbiased towards) every class. However, if we have prior knowledge of the distribution, we know we do in fact want the model to be biased: towards predicting dog, for example, if we know that 80% of the time it’s a dog. There is a trade-off, though. Your model will be strictly worse at predicting the other classes, but you will get better accuracy in your target domain. (A sketch of this with a weighted sampler follows after the list.)

  2. I don’t have much knowledge of this area, but you can also artificially shift your logits, either after the fact or during training. In binary classification, researchers often adjust the decision threshold to induce a bias of some kind (e.g., a cancer-detection model that prefers false positives over false negatives). You still train as normal and just adjust the cut-off threshold afterwards to get the bias you’re looking for. I’m not sure exactly what this would look like in the multi-class case, but I’m assuming there’s a multi-class variant. (A rough sketch of one follows after the list.)
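
For 1., here’s a minimal sketch of resampling the training set to match a known target prior with `WeightedRandomSampler`. The priors and labels below are made up, and `train_dataset` is a placeholder for your own dataset:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Made-up target distribution for k=3 classes.
target_prior = torch.tensor([0.80, 0.15, 0.05])

# Placeholder for your real training labels (class indices 0..2).
labels = torch.randint(0, 3, (1000,))

# Per-sample weight = target_prior[c] / count(c), so sampled batches
# follow target_prior in expectation.
class_counts = torch.bincount(labels, minlength=3).float()
sample_weights = target_prior[labels] / class_counts[labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
# loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)
```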
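
For 2., my guess at the multi-class variant: shift each logit by the log-ratio of target prior to training prior (uniform, if you balanced your classes). The numbers here are invented, and the random logits stand in for a real `model(x)` call:

```python
import torch

train_prior = torch.full((3,), 1 / 3)            # balanced training set -> uniform prior
target_prior = torch.tensor([0.80, 0.15, 0.05])  # made-up prior of the new dataset

def adjust_logits(logits, train_prior, target_prior):
    # softmax(z + log q - log p) is proportional to softmax(z) * q / p,
    # i.e. the model's posterior reweighted by the prior ratio.
    return logits + torch.log(target_prior) - torch.log(train_prior)

# Stand-in for logits = model(x), shape (batch, 3).
logits = torch.randn(4, 3)
probs = torch.softmax(adjust_logits(logits, train_prior, target_prior), dim=1)
preds = probs.argmax(dim=1)
```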

Thank you for your ideas. For 1), there isn’t a single distribution to use as a prior: there are many different ones, because I’m running inference in various geographies, each with a different final distribution. I think using a single imbalanced proportion as weights during training would likely degrade the inferences, which is why I went with the balanced case :slight_smile:

Nah, I misunderstood. For the “no retraining” case, the suggested shifting of logits may work: it basically rescales the one-vs-rest odds ratios (you can check the math on Wikipedia). It’s also similar to probability calibration (Platt scaling), but since you don’t retrain, you’ll be using a constant rescaler.
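
For reference, here’s my reading of the math, assuming label shift (the class-conditional p(x|y) is the same in both datasets):

```latex
% q(y): known prior on the new dataset; p(y): training prior
% (uniform if you balanced your classes). Bayes' rule gives:
p_{\mathrm{adj}}(y \mid x) \;\propto\; p_{\mathrm{model}}(y \mid x) \cdot \frac{q(y)}{p(y)}
% In logit space this is a constant per-class shift:
% z'_y = z_y + \log q(y) - \log p(y)
```

Since you have a different prior per geography, the shift just becomes a different constant vector per geography, computed once from each known distribution.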