Multilabel Multiclass Imbalanced Text Classifier

I have a list of patient symptom texts that I want to classify in a multi-label fashion with BERT. The problem is that there are thousands of classes (labels), and they are very imbalanced.

Input:
The patient reports headache and fatigue

Output:
Fatigue, headache

Here are some approaches I am considering:
1. OneVsRest models + datasets: Stack multiple OneVsRest BERT models, each trained on its own balanced OneVsRest dataset. The problem is that this is HUGE with so many stacked models. Additionally, PyTorch doesn't register the individual models as parameters when I assign them as a plain dictionary of layers (see the ModuleDict sketch after this list).
2. OneVsRest datasets: Build one dataset per outcome, then feed them all to the same model with an additional input field denoting which outcome to predict.
3. Class weights: I tried this in the past and it didn't seem to work well. Is it possible for data to be too imbalanced for class weights to help? (A weighted-loss sketch of what I mean also follows below.)
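
For context on point 1, the dictionary issue seems to come from using a plain Python dict, which PyTorch does not register; `torch.nn.ModuleDict` does. A minimal sketch of what I mean (label names are illustrative, and I show a single shared encoder to keep it small, but the same registration point applies if each entry is a full BERT model):

```python
import torch
import torch.nn as nn
from transformers import AutoModel

class SharedEncoderOneVsRest(nn.Module):
    """One shared BERT encoder with one binary head per label.

    nn.ModuleDict (unlike a plain Python dict) registers each head as a
    submodule, so its weights show up in model.parameters() and get trained.
    """
    def __init__(self, label_names, encoder_name="bert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden = self.encoder.config.hidden_size
        self.heads = nn.ModuleDict({name: nn.Linear(hidden, 1) for name in label_names})

    def forward(self, input_ids, attention_mask):
        # Use the [CLS] token representation as a pooled summary of the report.
        cls = self.encoder(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state[:, 0]
        # One logit per label, concatenated into a (batch, n_labels) tensor.
        return torch.cat([head(cls) for head in self.heads.values()], dim=-1)

model = SharedEncoderOneVsRest(["fatigue", "headache"])
print(len(list(model.heads.parameters())))  # non-zero: the heads are registered
```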
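
And for point 3, by "class weights" I mean per-label positive weights in `BCEWithLogitsLoss`, roughly like this minimal sketch (all counts here are placeholders):

```python
import torch
import torch.nn as nn

n_labels = 4                       # placeholder; thousands in my case
n_samples = 10_000                 # placeholder dataset size
pos_counts = torch.tensor([5000., 300., 20., 2.])  # positives per label (placeholder)

# pos_weight = (#negatives / #positives) per label; very rare labels get huge
# weights, which is one way heavy imbalance can make plain weighting unstable.
pos_weight = (n_samples - pos_counts) / pos_counts
pos_weight = pos_weight.clamp(max=100.0)  # optional cap to keep gradients sane

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(8, n_labels)                     # model outputs for a batch of 8
targets = torch.randint(0, 2, (8, n_labels)).float()  # multi-hot label matrix
loss = criterion(logits, targets)
```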

Other Considerations:

  • As I understand it, SMOTE does not work for multi-label data

You can try two options:

  1. Balance your dataset by augmenting it with composed samples (combining pieces of different reports from the same class); see the sketch after this list.
  2. Use a self-supervised technique. For example, you can mask parts of a report and ask the model to predict those masked tokens. Then your model learns the overall structure of your text, and you can fine-tune it with only a few labelled samples.
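
A minimal sketch of the first option, assuming your data is a list of (text, label-set) pairs; here I simply concatenate whole reports that share the same label set (you could also splice sentences), and the helper name and example reports are only for illustration:

```python
import random

def compose_samples(dataset, n_new, seed=0):
    """Create new multi-label samples by concatenating pairs of reports
    that share the exact same label set (a naive composition strategy)."""
    rng = random.Random(seed)
    # Group reports by their label set so composed samples stay label-consistent.
    by_labels = {}
    for text, labels in dataset:
        by_labels.setdefault(frozenset(labels), []).append(text)

    eligible = [k for k, texts in by_labels.items() if len(texts) >= 2]
    composed = []
    for _ in range(n_new):
        labels = rng.choice(eligible)
        first, second = rng.sample(by_labels[labels], 2)
        composed.append((first + " " + second, set(labels)))
    return composed

# Usage: oversample rare label sets by composing extra reports for them.
data = [
    ("The patient reports headache and fatigue", {"headache", "fatigue"}),
    ("Persistent fatigue and a dull headache since Monday", {"headache", "fatigue"}),
]
augmented = data + compose_samples(data, n_new=2)
```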

Hey Arman,

Thanks for your input. I am new to transformers, so I appreciate your help! In response to your recommendations:

  1. This sounds to me like SMOTE. Is that correct? Also, as I understand it, SMOTE doesn't work well with text data.
  2. This is an interesting idea. As I understand it, I would encode the masked sequence and then decode to recover the masked token, so the input would look like this:

[CLS]The patient reports headache and fatigue[SEP]Fatigue, [MASK][SEP]

Wouldn't this imply to the model that the output has an inherent order to it (i.e. that fatigue comes before headache)?

  1. I agree that SMOTE does not generally work for text data, but a few tricks might make it usable. For example, instead of randomly mixing different parts of multiple texts, you could concatenate two pieces of text from different classes and train your model to output the same probability for both of those classes. Nothing is guaranteed; the best option is to try!

  2. I assume you are using a pre-trained BERT. BERT is just a stack of encoders, so there is no decoder! What you would be doing is Masked Language Modelling (MLM); the Huggingface documentation covers it in detail. In short, you mask some tokens and ask the model to predict them (a minimal sketch follows after this list).

  3. About your last point, I agree that this training regime might introduce an ordering bias in your model, and you may want to preprocess your data to reduce that bias. As a side note, I think there may be better options than BERT for this task. For instance, you could extract tf-idf features from your documents and use a simple SVM classifier, which is insensitive to word order (see the second sketch below).
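
To make point 2 concrete, here is a minimal MLM sketch using the Huggingface `transformers` and `datasets` libraries. Note that the masking is applied to the report text itself (there is no label suffix), and `reports` is a placeholder for your unlabelled texts:

```python
from datasets import Dataset
from transformers import (AutoModelForMaskedLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

reports = [
    "The patient reports headache and fatigue",
    "Patient complains of persistent nausea",
]  # placeholder: use your full unlabelled corpus here

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

# Tokenize the raw reports; the text column is dropped after tokenization.
dataset = Dataset.from_dict({"text": reports})
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True, remove_columns=["text"],
)

# The collator randomly masks 15% of tokens and builds the MLM labels.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="mlm-bert", num_train_epochs=1,
                           per_device_train_batch_size=16),
    train_dataset=dataset,
    data_collator=collator,
)
trainer.train()
# Afterwards, load the adapted encoder into a classification head and
# fine-tune it on your (small) labelled set.
```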
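
And for point 3, a minimal tf-idf + linear SVM baseline with scikit-learn; `texts` and `labels` are placeholders for your data:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import LinearSVC

# Placeholder data: one list of label names per report.
texts = ["The patient reports headache and fatigue", "Patient complains of nausea"]
labels = [["headache", "fatigue"], ["nausea"]]

# Turn the label-name lists into a binary indicator matrix (one column per label).
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(labels)

# tf-idf features + one linear SVM per label; class_weight="balanced" reweights
# the rare positive class inside each per-label binary problem.
clf = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 2)),
    OneVsRestClassifier(LinearSVC(class_weight="balanced"), n_jobs=-1),
)
clf.fit(texts, Y)

pred = clf.predict(["The patient reports headache and fatigue"])
print(mlb.inverse_transform(pred))  # e.g. [('fatigue', 'headache')]
```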

Thank you for your advice! I will try these approaches.

Jacob Clarke

You might want to look at meta-learning (e.g. MAML). I don't know if there has been any work done with BERT encoders or NLP extraction specifically, but the general methodology should apply to your context as well. Honestly, I think dealing with imbalanced data is a really nontrivial problem :wink: