Fine tuning electra for text classification is giving awful results

Hey everyone, I’ve been trying to do text classification via finetuning a huggingface model (I chose electra cuz of its small size). I’m getting awful results where it basically prints out almost the same logits for all the validation inputs and gets either 0 accuracy or 42% accuracy depending on if a logit has a value > 0.5. I’m very confused and need some help understanding why.

import warnings
import pandas as pd
from transformers import ElectraTokenizer, ElectraForSequenceClassification
import torch
import numpy as np
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Suppress FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Load your CSV data
train_file = 'C:\\temp\\train_new.csv'
test_file = 'C:\\temp\\test_new.csv'
data = pd.read_csv(train_file)

data.fillna("n/a", inplace=True)
print ("data.shape", data.shape)

# Split the multi-labels into individual labels
data['Mapped Keywords'] = data['Mapped Keywords'].str.split(', ')

# Initialize MultiLabelBinarizer
mlb = MultiLabelBinarizer()
targets = mlb.fit_transform(data['Mapped Keywords'])

# Tokenize your data
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')

# Concatenate values from columns 2 through 8, ignoring "n/a" values and extra whitespace
concatenated_text = data.iloc[:, 1:8].apply(lambda x: ' '.join(filter(lambda y: y != 'n/a', x)), axis=1).str.strip()

inputs = tokenizer(concatenated_text.tolist(), padding=True, truncation=True, return_tensors="pt")

# Define your dataset
class CustomDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets
        
    def __len__(self):
        return len(self.inputs['input_ids'])

    def __getitem__(self, idx):
        input_item = {key: val[idx] for key, val in self.inputs.items()}  # Tokenized inputs for the idx
        target_item = torch.tensor(self.targets[idx], dtype=torch.float32)  # Convert target to tensor
        return input_item, target_item

# Define your dataset
overall_data = CustomDataset(inputs, targets)
print (len(overall_data))

# Split your data into training and validation sets
train_idx, val_idx = train_test_split(range(len(overall_data)), test_size=0.2, random_state=2724)
train_data = DataLoader([overall_data[i] for i in train_idx], batch_size=20, shuffle=True)
val_data = DataLoader([overall_data[i] for i in val_idx], batch_size=1)
print ("lentraindata:", len(train_data))

# Define your model architecture
model = ElectraForSequenceClassification.from_pretrained('google/electra-small-discriminator', num_labels=len(mlb.classes_))

# Define your optimizer
optimizer = AdamW(model.parameters(), lr=5e-4)

# Define your loss function for multi-label classification
loss_function = nn.BCEWithLogitsLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train your model
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    print('training')
    for batch_idx, batch in enumerate(train_data):
        inputs, targets = batch
        inputs = {key: val.to(device) for key, val in inputs.items()}
        targets = targets.to(device)
        
        outputs = model(**inputs)
        optimizer.zero_grad()
        outputs.loss.backward()
        optimizer.step()
        logits = outputs.logits
        
        # Compute the loss using binary cross-entropy with logits loss
        loss = loss_function(logits, targets)
        print(loss)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Evaluate your model
model.eval()

val_predictions = []
val_targets = []

with torch.no_grad():
    for batch_idx, batch in enumerate(val_data):
        inputs, targets = batch
        inputs = {key: val.to(device) for key, val in inputs.items()}
        targets = targets.to(device)
        
        outputs = model(**inputs)
        logits = outputs.logits
        print ("logit", batch_idx, logits)
        predictions = torch.sigmoid(logits)  # Apply sigmoid to logits for multi-label classification
        predictions = (predictions >= 0.5).float()  # Threshold predictions
        
        val_predictions.extend(predictions.cpu().numpy())
        val_targets.extend(targets.cpu().numpy())

print ("val pred:", val_predictions)
print ("-------------")
print ("val targ:", val_targets)
# Calculate metrics
val_accuracy = accuracy_score(val_targets, val_predictions)
val_precision = precision_score(val_targets, val_predictions, average='weighted')
val_recall = recall_score(val_targets, val_predictions, average='weighted')
val_f1 = f1_score(val_targets, val_predictions, average='weighted')

print("Validation Accuracy:", val_accuracy)
print("Validation Precision:", val_precision)
print("Validation Recall:", val_recall)
print("Validation F1-score:", val_f1)

I know rnns are vastly superior to electra, would it be worth switching to that?