I'm trying to do multilabel classification for sentiment analysis.
First I split the data and pass each batch through a BERT layer, which encodes it and produces embeddings.
Then I have a NN that takes this BERT output and classifies it, producing a 16-value output (since I'm classifying according to the MBTI taxonomy).
I use BCEWithLogitsLoss to compare the generated output against the multilabel targets.
I convert my labels to integers from 0 to 15 and then build a 16x16 matrix (batch_size x num_classes) where each row is one-hot at the desired class.
But I'm wondering: BCEWithLogitsLoss compares values between 0 and 1, yet only the "1" positions are actually relevant to me.
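To make it concrete, this is the kind of input/target pair the loss sees (a minimal sketch with a batch of 2 and made-up labels 3 and 0):

import torch
import torch.nn as nn

logits = torch.randn(2, 16)                     # raw outputs of the classifier head
targets = torch.zeros(2, 16)
targets[0, 3] = 1.0                             # sample 0 belongs to class 3
targets[1, 0] = 1.0                             # sample 1 belongs to class 0
loss = nn.BCEWithLogitsLoss()(logits, targets)  # sigmoid is applied internally; every
                                                # position, the 0s included, adds to the loss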
Should I write a custom loss function, or take a different approach?
The code is:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BERT_VARIANTS_CLS_LAYER_SIZE = 768  # hidden size of bert-base-uncased

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
def set_bert_required_grad(value: bool = True):
    for param in bert_model.parameters():
        param.requires_grad = value
class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        return self.texts[index]
class PersonalityDetectionModel(nn.Module):
    def __init__(self):
        super(PersonalityDetectionModel, self).__init__()
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(BERT_VARIANTS_CLS_LAYER_SIZE, 512)
        # batch_first=True so the unsqueezed dim in forward() is treated as a
        # length-1 sequence per sample, not as the batch dimension
        self.attention = nn.MultiheadAttention(embed_dim=512, num_heads=16, dropout=0.2, batch_first=True, device=device)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 16)
        self.relu = nn.ReLU()

    def forward(self, posts):
        posts = self.dropout(posts)
        posts = self.relu(self.fc(posts))
        posts = posts.unsqueeze(1)  # Prepare for multihead attention (add sequence dimension)
        posts, _ = self.attention(posts, posts, posts)
        posts = posts.squeeze(1)  # Remove sequence dimension
        posts = self.relu(self.fc1(posts))
        posts = self.dropout(posts)
        posts = self.relu(self.fc2(posts))
        posts = self.dropout(posts)
        posts = self.fc3(posts)  # raw logits, one per class
        return posts
def encode_batch(texts_batch):
    encoded_inputs = tokenizer(texts_batch, padding=True, truncation=True, return_tensors="pt").to(device)
    output = bert_model(**encoded_inputs)
    return output.last_hidden_state[:, 0, :]  # cls token
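# For a batch of B texts, encode_batch returns a (B, 768) tensor: the
# last_hidden_state vector at position 0 (the [CLS] token) for each text.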
# One hot encoding
def labels_to_multilabel(local_labels, num_classes=16):
    multilabels = torch.zeros((local_labels.size(0), num_classes), device=local_labels.device)
    for idx, label in enumerate(local_labels):
        multilabels[idx][label] = 1
    return multilabels
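# e.g. labels_to_multilabel(torch.tensor([3, 0])) gives
#   [[0, 0, 0, 1, 0, ..., 0],
#    [1, 0, 0, 0, 0, ..., 0]]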
batch_size = 16
model = PersonalityDetectionModel().to(device)
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 1e-3},
    {'params': bert_model.parameters(), 'lr': 5e-5}
])
criterion = nn.BCEWithLogitsLoss()
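# BCEWithLogitsLoss applies the sigmoid internally and treats each of the 16
# outputs as an independent binary target, so the 0 entries of the one-hot rows
# contribute to the loss just like the 1 entries do.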
type_to_label = {
    "INTJ": 0,
    "INTP": 1,
    "INFJ": 2,
    "INFP": 3,
    "ENTJ": 4,
    "ENTP": 5,
    "ENFJ": 6,
    "ENFP": 7,
    "ISTJ": 8,
    "ISFJ": 9,
    "ISTP": 10,
    "ISFP": 11,
    "ESTJ": 12,
    "ESFJ": 13,
    "ESTP": 14,
    "ESFP": 15
}
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['posts'].to_list(),
    df['type'].map(type_to_label).values,
    test_size=0.8,
    random_state=1337,
    shuffle=True
)
train_dataset = TextDataset(train_texts)
val_dataset = TextDataset(val_texts)
# shuffle must stay False here: the training loop below looks labels up by batch
# position (train_labels[i * batch_size : (i + 1) * batch_size]), so shuffling the
# texts would misalign them with their labels
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
def train_model():
    set_bert_required_grad()
    model.train()
    bert_model.train()

    num_epochs = 1
    for epoch in range(num_epochs):
        total_correct = 0
        total_samples = 0

        for i, texts_batch in enumerate(train_dataloader):
            encoded_batch = encode_batch(texts_batch)
            local_labels = torch.tensor(train_labels[i * batch_size : (i + 1) * batch_size]).to(device)
            multi_labels = labels_to_multilabel(local_labels)

            optimizer.zero_grad()
            outputs = model(encoded_batch)
            loss = criterion(outputs, multi_labels)
            loss.backward()
            optimizer.step()

            # "accuracy" here is really recall on the positive labels: how many of
            # the 1-positions were predicted above the 0.5 threshold
            probs = torch.sigmoid(outputs)
            predicted_labels = (probs > 0.5).float()
            correct_predictions = ((predicted_labels == 1) & (multi_labels == 1)).float().sum().item()
            total_positives = multi_labels.sum().item()
            total_correct += correct_predictions
            total_samples += total_positives

            if i % 10 == 0:
                accuracy = total_correct / total_samples if total_samples > 0 else 0.0
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}")

        epoch_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
        print(f"Epoch [{epoch+1}/{num_epochs}] Accuracy: {epoch_accuracy:.4f}")