Preamble
Hello, I’m trying to solve a classification task where, given a text sample, I have to infer three labels: label_1_binary, label_2_binary and label_3_multiclass. Two of them are binary, i.e. they can each take one of two values. The third is multi-class, i.e. it can take one of n possible values (four, in my case).
My idea was to fine-tune BERT, take the final hidden state of the [CLS] token, feed it to a “common layer”, and then pass that layer’s output to three separate linear layers to obtain the logits for each classification. Unfortunately, something is not working: the F1 scores I get are extremely low, even after accounting for the unbalanced classes and tuning the hyper-parameters with Ray Tune.
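For reference, the architecture described above implies a loss setup along these lines; this is a minimal sketch assuming BCEWithLogitsLoss for the two binary heads and CrossEntropyLoss for the multi-class one (compute_loss and the plain unweighted sum are illustrative, not my exact training code):

import torch
from torch import nn

bce = nn.BCEWithLogitsLoss()  # binary heads emit a single logit each
ce = nn.CrossEntropyLoss()    # multi-class head emits four logits

def compute_loss(outputs, targets):
    # outputs: dict of logits from the model's forward pass
    # targets: dict of ground-truth tensors with matching keys
    loss_1 = bce(outputs["label_1_binary"].squeeze(-1),
                 targets["label_1_binary"].float())
    loss_2 = bce(outputs["label_2_binary"].squeeze(-1),
                 targets["label_2_binary"].float())
    loss_3 = ce(outputs["label_3_multiclass"],
                targets["label_3_multiclass"])
    return loss_1 + loss_2 + loss_3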
Question
Being new to NLP and PyTorch, I’d appreciate your feedback on the following points.
- Is the structure of the model conceptually correct, i.e. can the model be trained to return three separate sets of logits for the labels I’m trying to infer?
- Or would it be better to have the model return a single shared representation and then process it with three separate classifiers, e.g. MLPs? (A sketch of what I mean follows this list.)
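To illustrate the second point, this is roughly what I have in mind; a minimal sketch where the encoder only returns the shared 768-dim representation and three stand-alone MLPs classify on top of it (the Head class, hidden size and names are made up for illustration):

import torch
from torch import nn

class Head(nn.Module):
    """Stand-alone MLP classifier on top of the shared representation."""

    def __init__(self, out_features: int, hidden: int = 256) -> None:
        super().__init__()
        self._mlp = nn.Sequential(
            nn.Linear(768, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features)
        )

    def forward(self, shared: torch.Tensor) -> torch.Tensor:
        return self._mlp(shared)

# `shared` would be the common-layer output from the main model
head_1 = Head(out_features=1)  # label_1_binary
head_2 = Head(out_features=1)  # label_2_binary
head_3 = Head(out_features=4)  # label_3_multiclass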
The model
from typing import Dict

import torch
from torch import nn
from transformers import BertModel


class CustomModel(nn.Module):
    def __init__(self, drop_rate: float = 0.15) -> None:
        super().__init__()
        # Pre-trained Italian BERT as the shared encoder
        self._language_model = BertModel.from_pretrained(
            "dbmdz/bert-base-italian-uncased"
        )
        # Common layer applied to the CLS representation before the heads
        self._common_layer = nn.Sequential(
            nn.Linear(768, 768),
            nn.Tanh(),
            nn.Dropout(drop_rate)
        )
        # One logit per binary label, four logits for the multi-class label
        self._label_1_bin_layer = nn.Linear(768, 1)
        self._label_2_bin_layer = nn.Linear(768, 1)
        self._label_3_mul_layer = nn.Linear(768, 4)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # last_hidden_state[:, 0, :] is the final hidden state of [CLS]
        CLS_representation = self._language_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )[0][:, 0, :]  # type: ignore
        common_output = self._common_layer(CLS_representation)
        return {
            "label_1_binary": self._label_1_bin_layer(common_output),
            "label_2_binary": self._label_2_bin_layer(common_output),
            "label_3_multiclass": self._label_3_mul_layer(common_output)
        }
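For completeness, a quick smoke test of the forward pass (the tokenizer checkpoint matches the encoder; the Italian sample sentence is arbitrary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-italian-uncased")
model = CustomModel()
model.eval()

# BertTokenizer returns input_ids, token_type_ids and attention_mask,
# matching the forward() signature above
batch = tokenizer(
    ["Questo è solo un esempio."],
    padding=True,
    truncation=True,
    return_tensors="pt"
)
with torch.no_grad():
    outputs = model(**batch)

for name, logits in outputs.items():
    print(name, tuple(logits.shape))
# label_1_binary (1, 1)
# label_2_binary (1, 1)
# label_3_multiclass (1, 4)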