How to train a classifier to return logits for three separate labels

Hello, I’m trying to solve a classification task where, given a text sample, I have to infer three labels: label_1_binary, label_2_binary and label_3_multiclass. Two of them are binary, i.e., they can only take one of two values. The third label is multi-class, i.e., it can take one of n possible values.

My idea was to fine-tune BERT, take the representation of the [CLS] token, feed it to a “common layer”, and then pass the output of that layer to three separate linear layers that produce the logits for each label. Unfortunately, something is not working: the F1 scores I get are extremely low, even after accounting for the class imbalance and tuning the hyperparameters with Ray Tune.

Being new to NLP and PyTorch, I’d appreciate your feedback on the following points:

  1. Is the structure of the model conceptually correct, i.e., can the model be trained to return three separate sets of logits for the labels I’m trying to infer?
  2. Or would it be better to have the model return a single logit and then process it with three separate classifiers, e.g., MLPs?
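Training three heads jointly typically means computing one loss per head and backpropagating their sum. Below is a minimal sketch, assuming the binary heads each output a single logit of shape (B, 1) and the multi-class head outputs four logits, as in the model below. The helper name `multi_head_loss` and the unweighted sum are illustrative choices, not something taken from the post:

```python
from typing import Dict

import torch
import torch.nn as nn

# One loss per head: BCEWithLogitsLoss for the two single-logit binary heads,
# CrossEntropyLoss for the 4-way multi-class head. Both expect raw logits.
bce = nn.BCEWithLogitsLoss()
ce = nn.CrossEntropyLoss()


def multi_head_loss(
    logits: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
) -> torch.Tensor:
    """Hypothetical helper: unweighted sum of the three per-head losses."""
    # Binary heads output (B, 1); targets are 0/1 values of shape (B,), cast to float.
    loss_1 = bce(logits["label_1_binary"].squeeze(-1), targets["label_1_binary"].float())
    loss_2 = bce(logits["label_2_binary"].squeeze(-1), targets["label_2_binary"].float())
    # Multi-class head outputs (B, 4); targets are int64 class indices of shape (B,).
    loss_3 = ce(logits["label_3_multiclass"], targets["label_3_multiclass"].long())
    return loss_1 + loss_2 + loss_3


# Usage with random logits and labels for a batch of 3 samples
logits = {
    "label_1_binary": torch.randn(3, 1),
    "label_2_binary": torch.randn(3, 1),
    "label_3_multiclass": torch.randn(3, 4),
}
targets = {
    "label_1_binary": torch.tensor([0, 1, 1]),
    "label_2_binary": torch.tensor([1, 0, 0]),
    "label_3_multiclass": torch.tensor([0, 3, 2]),
}
print(multi_head_loss(logits, targets))
```

Because both built-in losses apply the sigmoid/softmax internally, the heads should return raw logits. For class imbalance, `nn.BCEWithLogitsLoss(pos_weight=...)` and `nn.CrossEntropyLoss(weight=...)` accept per-class weights; weighting the three terms differently is another knob if one task dominates the sum.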

The model

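from typing import Dict
import torch
import torch.nn as nn
from transformers import BertModel


class CustomModel(nn.Module):
    def __init__(self, pretrained_name: str, drop_rate: float = 0.15) -> None:
        super().__init__()  # required before registering sub-modules

        # The checkpoint name is cut off in the post; `pretrained_name` is a
        # placeholder for a BERT-base checkpoint (hidden size 768).
        self._language_model = BertModel.from_pretrained(pretrained_name)
        # Shared layer; ReLU + Dropout are assumed (the rest of the block is cut off)
        self._common_layer = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(drop_rate),
        )
        # One logit per binary label, four logits for the multi-class label
        self._label_1_bin_layer = nn.Sequential(
            nn.Linear(768, 1)
        )
        self._label_2_bin_layer = nn.Sequential(
            nn.Linear(768, 1)
        )
        self._label_3_mul_layer = nn.Sequential(
            nn.Linear(768, 4)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # Last-layer hidden state of the [CLS] token (position 0)
        CLS_representation = self._language_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )[0][:, 0, :]

        common_layer = self._common_layer(CLS_representation)

        # Raw logits (no sigmoid/softmax applied here)
        return {
            "label_1_binary": self._label_1_bin_layer(common_layer),
            "label_2_binary": self._label_2_bin_layer(common_layer),
            "label_3_multiclass": self._label_3_mul_layer(common_layer),
        }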

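Yes, creating multiple classifier “heads” is a valid approach. I don’t know how your data is distributed, but I would recommend scaling the use case down first and trying to overfit a small dataset (e.g., just 10 samples) by adapting the hyperparameters. If the model cannot fit even those few samples, the problem most likely lies in the model or training code rather than in the data.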
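A rough version of that sanity check could look like the sketch below. To keep it runnable without downloading BERT, it swaps in a hypothetical tiny encoder (`TinyMultiHead`, an `nn.EmbeddingBag` over random token ids) while keeping the shared-layer-plus-three-heads structure from the question; the sizes, learning rate and step count are arbitrary assumptions. If the wiring and losses are correct, the summed loss on 10 fixed samples should fall far below its starting value:

```python
from typing import List, Tuple

import torch
import torch.nn as nn


class TinyMultiHead(nn.Module):
    """Hypothetical stand-in for CustomModel: tiny encoder + shared layer + 3 heads."""

    def __init__(self, vocab_size: int = 100, hidden: int = 32) -> None:
        super().__init__()
        # Mean-pools the token embeddings of each row into one vector per sample
        self.encoder = nn.EmbeddingBag(vocab_size, hidden, mode="mean")
        self.common = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
        self.head_1 = nn.Linear(hidden, 1)  # binary label 1
        self.head_2 = nn.Linear(hidden, 1)  # binary label 2
        self.head_3 = nn.Linear(hidden, 4)  # 4-way label 3

    def forward(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h = self.common(self.encoder(input_ids))
        return self.head_1(h), self.head_2(h), self.head_3(h)


def overfit_tiny_batch(steps: int = 300, n_samples: int = 10) -> List[float]:
    """Train repeatedly on one fixed tiny batch and return the loss at every step."""
    torch.manual_seed(0)
    # Random "text" (8 token ids per sample) and random labels for all three tasks
    input_ids = torch.randint(0, 100, (n_samples, 8))
    y1 = torch.randint(0, 2, (n_samples,)).float()
    y2 = torch.randint(0, 2, (n_samples,)).float()
    y3 = torch.randint(0, 4, (n_samples,))

    model = TinyMultiHead()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    bce, ce = nn.BCEWithLogitsLoss(), nn.CrossEntropyLoss()

    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        out_1, out_2, out_3 = model(input_ids)
        loss = (
            bce(out_1.squeeze(-1), y1)
            + bce(out_2.squeeze(-1), y2)
            + ce(out_3, y3)
        )
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


losses = overfit_tiny_batch()
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

With the real model, the same loop would feed 10 tokenized samples through `CustomModel`; a loss that stays flat there points at the wiring, label encoding or loss setup rather than at the hyperparameter search.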