Same metrics across 11 epochs (model not training)

For this model I use a transformer (DistilBERT) and take a weighted average of its last 4 hidden layers. After that I concatenate the result with some extra features I have and run it through an MLP. It is a multi-class classification task, but I always get the same results:
acc 0.7270668176670442
matthews_corrcoef 0.0
f1 0.6121664222193344
f1_macro 0.21049180327868855
f1_micro 0.7270668176670442
f1_classwise [0. 0.84196721 0. 0. ]

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel

class WeightedAverageTransformerMLP(nn.Module):
    def __init__(
        self,
        model_name,
        num_extra_dims,
        num_labels,
        number_of_layers_to_concat=4,
        hidden_size=256,
    ):
        super().__init__()
        self.number_of_layers_to_concat = number_of_layers_to_concat

        self.config = AutoConfig.from_pretrained(model_name)
        self.config.update({"output_hidden_states": True})

        self.transformer = AutoModel.from_pretrained(model_name, config=self.config)

        self.layer_weights = nn.Parameter(
            torch.tensor([1] * self.number_of_layers_to_concat, dtype=torch.float)
        )

        self.dropout = nn.Dropout(0.1)
        self.mlp = nn.Sequential(
            nn.Linear(self.transformer.config.hidden_size + num_extra_dims, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, input_ids, extra_features, attention_mask=None):
        # with torch.no_grad():
        outputs = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )

        hidden_states = torch.stack(outputs.hidden_states)

        cls_embeddings = self._get_weighted_average(hidden_states)

        combined_features = torch.cat((cls_embeddings, extra_features), dim=-1)

        combined_features = self.dropout(combined_features)

        logits = self.mlp(combined_features)

        return logits

    def _get_weighted_average(self, hidden_states):
        chosen_layers = hidden_states[-self.number_of_layers_to_concat :, :, :, :]

        # turn layer weights into proper shape
        weight_factor = (
            self.layer_weights.unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .expand(chosen_layers.size())
        )

        weighted_average = (weight_factor * chosen_layers).sum(
            dim=0
        ) / self.layer_weights.sum()

        # keep only the first token of the sequence (CLS token) (batch, 768)
        cls_embeddings = weighted_average[:, 0, :]
        return cls_embeddings
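
For context, a dummy forward pass through this model looks roughly like this (a sketch: the checkpoint name, batch size and sequence length below are placeholders, while num_extra_dims=30 and num_labels=4 match the instantiation further down):

# Sketch: dummy forward pass with a placeholder checkpoint name and shapes (not my real data)
model = WeightedAverageTransformerMLP("distilbert-base-uncased", num_extra_dims=30, num_labels=4)
input_ids = torch.randint(0, model.config.vocab_size, (8, 128))  # (batch, seq_len)
attention_mask = torch.ones_like(input_ids)
extra_features = torch.randn(8, 30)                              # (batch, num_extra_dims)
logits = model(input_ids, extra_features, attention_mask)        # shape (8, 4)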

Also this is my training function:

from tqdm import tqdm


def train_batch(data_loader, model, loss_fn, optimizer, device):
    model.train()
    for batch in tqdm(
        data_loader, total=len(data_loader), leave=False, desc="Training Batches"
    ):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].squeeze(1).to(device)
        attention_mask = batch["attention_mask"].squeeze(1).to(device)
        features = batch["features"].to(device)
        labels = batch["label"].type(torch.LongTensor).to(device)

        logits = model(
            input_ids=input_ids, extra_features=features, attention_mask=attention_mask
        )
        
        loss = loss_fn(logits, labels)

        loss.backward()
        optimizer.step()

and the testing:

from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef


def test_batch(data_loader, model, loss_fn, device):
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    model.eval()
    test_loss, correct = 0, 0

    all_predictions = []
    all_labels = []


    with torch.no_grad():
        for batch in tqdm(
            data_loader, total=len(data_loader), leave=False, desc="Testing Batches"
        ):
            input_ids = batch["input_ids"].squeeze(1).to(device)
            attention_mask = batch["attention_mask"].squeeze(1).to(device)
            features = batch["features"].to(device)
            labels = batch["label"].type(torch.LongTensor).to(device)

            logits = model(
                input_ids=input_ids,
                extra_features=features,
                attention_mask=attention_mask,
            )

            test_loss += loss_fn(logits, labels).item()

            correct += (logits.argmax(1) == labels).type(torch.float).sum().item()

            # argmax over the logits directly; applying softmax first would not change the argmax
            predictions = torch.argmax(logits, dim=1)
            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    test_loss /= num_batches
    correct /= size

    f1 = f1_score(all_labels, all_predictions, average='weighted')
    f1_macro = f1_score(all_labels, all_predictions, average='macro')
    f1_micro = f1_score(all_labels, all_predictions, average='micro')
    f1_classwise = f1_score(all_labels, all_predictions, average=None)
    matthews = matthews_corrcoef(all_labels, all_predictions)
    acc = accuracy_score(all_labels, all_predictions)

    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
    print("acc", acc)
    print("matthews_corrcoef", matthews)
    print("f1", f1)
    print("f1_macro", f1_macro)
    print("f1_micro", f1_micro)
    print("f1_classwise", f1_classwise)

and the rest:

classifier = WeightedAverageTransformerMLP(classifier_name, 30, 4).to(device)

loss_fn = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(classifier.parameters())
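
For completeness, a sketch of the epoch loop that ties these together (train_loader and test_loader stand in for my actual DataLoaders; I run 11 epochs):

# Sketch of the epoch loop (train_loader / test_loader are placeholders for my DataLoaders)
for epoch in range(11):
    print(f"Epoch {epoch + 1}")
    train_batch(train_loader, classifier, loss_fn, optimizer, device)
    test_batch(test_loader, classifier, loss_fn, device)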

Check whether gradients are properly computed by iterating over all parameters of the model and printing their .grad attribute. After the backward call this attribute should contain valid gradient values. If instead you see None values, it means no gradients were computed for those parameters and the computation graph might be detached at that point.
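
For example, a minimal sketch of that check (here model is your classifier instance, and the loop should run right after loss.backward()):

# Sketch: run right after loss.backward() to see which parameters received gradients
for name, param in model.named_parameters():
    if param.grad is None:
        print(name, "-> grad is None")  # no gradient flowed back into this parameter
    else:
        print(name, "-> grad norm:", param.grad.norm().item())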

First of all, thank you for your answer. The gradients of the transformer weights are None, but the gradients of the layer weights (for the weighted average) and of the MLP weights are not None and are changing.
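
For reference, a quick way to check whether the transformer parameters are even marked as trainable (a sketch using the classifier instance from the code above):

# Sketch: count how many transformer parameters have requires_grad=True
trainable = sum(p.requires_grad for p in classifier.transformer.parameters())
total = sum(1 for _ in classifier.transformer.parameters())
print(f"transformer: {trainable}/{total} parameters require grad")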