For this model I use a transformer (DistilBERT) and take a learned weighted average over its last 4 hidden states. I then concatenate the resulting CLS embedding with some extra features I extracted and run the combination through an MLP. It is a multi-class classification problem (4 classes), but I always get the same results:
acc 0.7270668176670442
matthews_corrcoef 0.0
f1 0.6121664222193344
f1_macro 0.21049180327868855
f1_micro 0.7270668176670442
f1_classwise [0. 0.84196721 0. 0. ]
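A Matthews correlation of exactly 0 and a class-wise F1 with a single nonzero entry look like the model is predicting the same class for every example. A quick way to confirm that (prediction_distribution is just a throwaway helper, fed the all_predictions / all_labels lists collected in the test function below):

from collections import Counter

def prediction_distribution(all_predictions, all_labels):
    # if the model has collapsed onto one class, the first Counter has a
    # single key, while the second shows the true label distribution
    print("predictions:", Counter(all_predictions))
    print("labels:", Counter(all_labels))

Here is the model: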
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
class WeightedAverageTransformerMLP(nn.Module):
def __init__(
self,
model_name,
num_extra_dims,
num_labels,
number_of_layers_to_concat=4,
hidden_size=256,
):
super().__init__()
self.number_of_layers_to_concat = number_of_layers_to_concat
self.config = AutoConfig.from_pretrained(model_name)
self.config.update({"output_hidden_states": True})
self.transformer = AutoModel.from_pretrained(model_name, config=self.config)
        # one learnable scalar weight per chosen layer, initialised to 1
        self.layer_weights = nn.Parameter(
            torch.ones(self.number_of_layers_to_concat)
        )
self.dropout = nn.Dropout(0.1)
self.mlp = nn.Sequential(
nn.Linear(self.transformer.config.hidden_size + num_extra_dims, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, num_labels),
)
def forward(self, input_ids, extra_features, attention_mask=None):
# with torch.no_grad():
outputs = self.transformer(
input_ids=input_ids, attention_mask=attention_mask
)
hidden_states = torch.stack(outputs.hidden_states)
cls_embeddings = self._get_weighted_average(hidden_states)
combined_features = torch.cat((cls_embeddings, extra_features), dim=-1)
combined_features = self.dropout(combined_features)
logits = self.mlp(combined_features)
return logits
def _get_weighted_average(self, hidden_states):
        # keep the last N hidden states: (N, batch, seq_len, hidden)
        chosen_layers = hidden_states[-self.number_of_layers_to_concat:]
# turn layer weights into proper shape
        weight_factor = self.layer_weights.view(-1, 1, 1, 1)
weighted_average = (weight_factor * chosen_layers).sum(
dim=0
) / self.layer_weights.sum()
# keep only the first token of the sequence (CLS token) (batch, 768)
cls_embeddings = weighted_average[:, 0, :]
return cls_embeddings
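Before training I sanity-check the shapes with a throwaway forward pass; "distilbert-base-uncased", the batch size of 2, and the sequence length of 16 here are just placeholder assumptions, while the 30 extra features and 4 labels match the instantiation further down:

import torch

# hypothetical smoke test: random inputs, just to confirm the forward
# pass returns (batch, num_labels) logits
model = WeightedAverageTransformerMLP(
    "distilbert-base-uncased", num_extra_dims=30, num_labels=4
)
input_ids = torch.randint(0, model.config.vocab_size, (2, 16))  # (batch, seq_len)
attention_mask = torch.ones_like(input_ids)
extra_features = torch.randn(2, 30)

with torch.no_grad():
    logits = model(input_ids, extra_features, attention_mask)
print(logits.shape)  # torch.Size([2, 4])

One thing I am unsure about: layer_weights is unconstrained, so self.layer_weights.sum() can get arbitrarily close to zero; normalising the weights with a softmax instead would avoid that.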
Also this is my training function:
from tqdm import tqdm

def train_batch(data_loader, model, loss_fn, optimizer, device):
model.train()
for batch in tqdm(
data_loader, total=len(data_loader), leave=False, desc="Training Batches"
):
optimizer.zero_grad()
input_ids = batch["input_ids"].squeeze(1).to(device)
attention_mask = batch["attention_mask"].squeeze(1).to(device)
features = batch["features"].to(device)
labels = batch["label"].type(torch.LongTensor).to(device)
logits = model(
input_ids=input_ids, extra_features=features, attention_mask=attention_mask
)
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()
and the testing:
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef

def test_batch(data_loader, model, loss_fn, device):
size = len(data_loader.dataset)
num_batches = len(data_loader)
model.eval()
test_loss, correct = 0, 0
all_predictions = []
all_labels = []
with torch.no_grad():
for batch in tqdm(
data_loader, total=len(data_loader), leave=False, desc="Testing Batches"
):
input_ids = batch["input_ids"].squeeze(1).to(device)
attention_mask = batch["attention_mask"].squeeze(1).to(device)
features = batch["features"].to(device)
labels = batch["label"].type(torch.LongTensor).to(device)
logits = model(
input_ids=input_ids,
extra_features=features,
attention_mask=attention_mask,
)
test_loss += loss_fn(logits, labels).item()
            # softmax is monotonic, so argmax over the raw logits gives
            # the same predictions as argmax over the probabilities
            predictions = logits.argmax(dim=1)
            correct += (predictions == labels).type(torch.float).sum().item()
            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
test_loss /= num_batches
correct /= size
f1 = f1_score(all_labels, all_predictions, average='weighted')
f1_macro = f1_score(all_labels, all_predictions, average='macro')
f1_micro = f1_score(all_labels, all_predictions, average='micro')
f1_classwise = f1_score(all_labels, all_predictions, average=None)
matthews = matthews_corrcoef(all_labels, all_predictions)
acc = accuracy_score(all_labels, all_predictions)
print(
f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
)
print("acc", acc)
print("matthews_corrcoef", matthews)
print("f1", f1)
print("f1_macro", f1_macro)
print("f1_micro", f1_micro)
print("f1_classwise", f1_classwise)
and the rest:
classifier = WeightedAverageTransformerMLP(classifier_name, 30, 4).to(device)  # 30 extra features, 4 classes
loss_fn = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(classifier.parameters())
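These are then wired together with a plain epoch loop, roughly like this (train_loader, test_loader, and the epoch count are placeholders for my actual setup):

# placeholder driver loop; train_loader / test_loader are the DataLoaders
# that yield the batch dicts used in train_batch / test_batch above
num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_batch(train_loader, classifier, loss_fn, optimizer, device)
    test_batch(test_loader, classifier, loss_fn, device)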