HELP with multilabel classification and BCEWithLogitsLoss

I'm trying to do multilabel classification for sentiment analysis.

First I split the data and pass each batch through a BERT layer, which encodes the text and produces embeddings.
Then I feed those BERT embeddings into a neural network that classifies them and outputs 16 values, one per class, since I'm classifying according to the MBTI taxonomy.

Then I use BCEWithLogitsLoss to compare the model's output with the multilabel target.

I map each label to an integer from 0 to 15 and then one-hot encode the batch, so for a batch of 16 samples I get a 16x16 matrix where a 1 marks the desired class.
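
(For reference, as far as I understand this one-hot step is equivalent to what torch.nn.functional.one_hot produces; toy labels below, not my real data.)

import torch
import torch.nn.functional as F

labels = torch.tensor([2, 0, 15, 7])                  # toy integer labels, values 0..15
one_hot = F.one_hot(labels, num_classes=16).float()   # one row per sample, a single 1 per row
print(one_hot.shape)                                  # torch.Size([4, 16])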

But I'm wondering: I'm using BCEWithLogitsLoss, which compares values between 0 and 1, yet only the positions holding a "1" are really relevant to me.

Should I write a custom loss function, or take a different approach?
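
To make the question concrete, here is a toy comparison (made-up logits over 4 classes, not my real model outputs) of how BCEWithLogitsLoss consumes a one-hot target versus how CrossEntropyLoss would consume the integer label directly:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5, -0.3]])    # one sample, 4 classes (toy values)
target_index = torch.tensor([0])                   # the single true class
target_one_hot = torch.tensor([[1.0, 0.0, 0.0, 0.0]])

# BCEWithLogitsLoss: 4 independent binary decisions, averaged over every entry
bce = nn.BCEWithLogitsLoss()(logits, target_one_hot)

# CrossEntropyLoss: a single softmax over the 4 classes, takes the class index
ce = nn.CrossEntropyLoss()(logits, target_index)

print(bce.item(), ce.item())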

The code is:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BERT_VARIANTS_CLS_LAYER_SIZE = 768  # hidden size of bert-base-uncased

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Enable/disable gradient updates for the BERT weights (fine-tuning on/off)
def set_bert_required_grad(value:bool = True):
    for param in bert_model.parameters():
        param.requires_grad = value

class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, index):
        return self.texts[index]

class PersonalityDetectionModel(nn.Module):
    def __init__(self):
        super(PersonalityDetectionModel, self).__init__()

        self.dropout = nn.Dropout(0.4)


        self.fc = nn.Linear(BERT_VARIANTS_CLS_LAYER_SIZE, 512)
        # batch_first=True so the (batch, seq, embed) shape produced in forward() is interpreted correctly
        self.attention = nn.MultiheadAttention(embed_dim=512, num_heads=16, dropout=0.2, batch_first=True, device=device)

        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 16)

        self.relu = nn.ReLU()

    def forward(self, posts):        
        posts = self.dropout(posts)
        posts = self.relu(self.fc(posts))

        posts = posts.unsqueeze(1)  # Prepare for multihead attention (add sequence dimension)
        posts, _ = self.attention(posts, posts, posts)
        posts = posts.squeeze(1)  # Remove sequence dimension

        posts = self.relu(self.fc1(posts))
        posts = self.dropout(posts)

        posts = self.relu(self.fc2(posts))
        posts = self.dropout(posts)
        
        posts = self.fc3(posts)
        return posts

def encode_batch(texts_batch):
    encoded_inputs = tokenizer(texts_batch, padding=True, truncation=True, return_tensors="pt").to(device)
    output = bert_model(**encoded_inputs)
    return output.last_hidden_state[:, 0, :]  # cls token

# One hot encoding
def labels_to_multilabel(local_labels, num_classes=16):
    multilabels = torch.zeros((local_labels.size(0), num_classes), device=local_labels.device)
    
    for idx, label in enumerate(local_labels):
        multilabels[idx][label] = 1
    
    return multilabels

batch_size = 16

model = PersonalityDetectionModel().to(device)

optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 1e-3},
    {'params': bert_model.parameters(), 'lr': 5e-5}
])

criterion = nn.BCEWithLogitsLoss()

type_to_label = { 
    "INTJ": 0,
    "INTP": 1,
    "INFJ": 2,
    "INFP": 3,
    "ENTJ": 4,
    "ENTP": 5,
    "ENFJ": 6,
    "ENFP": 7,
    "ISTJ": 8,
    "ISFJ": 9,
    "ISTP": 10,
    "ISFP": 11,
    "ESTJ": 12,
    "ESFJ": 13,
    "ESTP": 14,
    "ESFP": 15
}

# note: test_size=0.8 puts 80% of the samples in the validation split
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['posts'].to_list(), 
    df['type'].map(type_to_label).values, 
    test_size=0.8, 
    random_state=1337,
    shuffle=True
)

train_dataset = TextDataset(train_texts)
val_dataset = TextDataset(val_texts)

# shuffle must stay False here: the training loop slices train_labels by batch index,
# so the text order coming out of the dataloader has to match the label order
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

def train_model():
    set_bert_required_grad()

    model.train()
    bert_model.train()

    num_epochs = 1
    for epoch in range(num_epochs):
        total_correct = 0
        total_samples = 0
        for i, texts_batch in enumerate(train_dataloader):
            encoded_batch = encode_batch(texts_batch)
            # labels for this batch, aligned with the (unshuffled) dataloader order
            local_labels = torch.tensor(train_labels[i * batch_size : (i + 1) * batch_size]).to(device)

            multi_labels = labels_to_multilabel(local_labels)

            optimizer.zero_grad()

            outputs = model(encoded_batch)

            loss = criterion(outputs, multi_labels)

            loss.backward()

            optimizer.step()

            probs = torch.sigmoid(outputs)

            predicted_labels = (probs > 0.5).float()

            # counts true positives only (predicted 1 where the target is 1)
            correct_predictions = ((predicted_labels == 1) & (multi_labels == 1)).float().sum().item()

            total_positives = multi_labels.sum().item()

            total_correct += correct_predictions
            total_samples += total_positives 

            if i % 10 == 0: 
                accuracy = total_correct / total_samples if total_samples > 0 else 0.0
                print(f"Epoch [{epoch+1}/10], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}")

        epoch_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
        print(f"Epoch [{epoch+1}/10] Accuracy: {epoch_accuracy:.4f}")

BCEWithLogitsLoss treats each of the 16 outputs as an independent binary (0 or 1) decision, which fits true multilabel problems. Here each sample has exactly one MBTI type, so this is single-label multi-class classification: use torch.nn.CrossEntropyLoss, which takes the integer class index (0 to 15) directly, so you need neither the one-hot encoding nor a custom loss.
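
A minimal sketch of that change against the code above (reusing the variable names from the post; the rest of the training loop stays the same):

criterion = nn.CrossEntropyLoss()

# inside the training loop: feed the integer labels (0..15) directly,
# no labels_to_multilabel / one-hot step needed
outputs = model(encoded_batch)             # (batch_size, 16) logits
loss = criterion(outputs, local_labels)    # local_labels: LongTensor of class indices

# for accuracy, take the argmax over the 16 classes
predicted_labels = outputs.argmax(dim=1)
correct_predictions = (predicted_labels == local_labels).sum().item()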