Model Built on Pre-Trained Embeddings Does Not Learn

Hello,

I am trying to build a small classification network on top of embeddings generated by a pre-trained model. The embeddings are pre-computed and then loaded in batches during training. The architecture is an attention layer followed by an MLP with several hidden layers, ending in a readout over the three possible classes. However, despite trying a wide range of hyperparameters and several architecture variants, the model does no better than random guessing on both the holdout set and the data it was trained on. These results are surprising, and it seems likely that I have made a basic error somewhere. Is there something I missed in my code? Any advice would be appreciated. My training code is included below, and a simplified sketch of my evaluation loop follows it.
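
For context on the tensor shapes: each pre-computed embedding is saved as a 2-D float tensor, I keep only its last 194 rows, and each row has 256 features, so a batch fed to the model has shape (batch_size, 194, 256). A minimal check of a single file looks like this (the file name is a stand-in; real names come from the training CSV):

import torch

# Hypothetical example file; real files are named '<name>_<class>.pt'
emb = torch.load('embeddings_r1/example_like.pt', map_location='cpu')
print(emb.shape)           # torch.Size([791, 256]) in my data
print(emb[597:, :].shape)  # torch.Size([194, 256]); batches are stacked to (batch_size, 194, 256)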

import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

class Net(nn.Module):  # Classifier
    def __init__(self):
        super().__init__()
        kernel_size = 9
        conv_dropout = 0.25
        # Attention block: two Conv1d layers over the 194 embedding rows (treated as channels)
        self.feature_convolution = nn.Conv1d(194, 194, kernel_size, stride=1, padding=kernel_size // 2)
        self.attention_convolution = nn.Conv1d(194, 194, kernel_size, stride=1, padding=kernel_size // 2)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(conv_dropout)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.5)
        # MLP: fc1/fc2 act on each of the 194 rows, fc3-fc5 act on the flattened 1940-dim vector
        self.fc1 = nn.Linear(256, 100)
        self.fc2 = nn.Linear(100, 10)
        self.fc3 = nn.Linear(1940, 1000)
        self.fc4 = nn.Linear(1000, 100)
        self.fc5 = nn.Linear(100, 3)

    def forward(self, x):
        # x: (batch, 194, 256)
        x = self.feature_convolution(x)
        x = self.dropout(x)
        attention = self.attention_convolution(x)
        x = x * self.softmax(attention)
        x = torch.squeeze(x)  # Finish attention layer
        x = F.relu(x)
        x = self.fc1(x)  # (batch, 194, 256) -> (batch, 194, 100)
        x = F.relu(x)
        x = self.fc2(x)  # (batch, 194, 10)
        x = torch.flatten(x, 1)  # (batch, 1940)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.fc4(x)
        x = self.dropout3(x)
        x = self.fc5(x)  # Finish MLP; (batch, 3)
        x = F.softmax(x, dim=1)  # Readout
        return x

def load_data(dataframe, cur_batch, bs):
    # Load one batch of pre-computed embeddings and integer class labels
    Y_temp = []
    X_temp = []
    mydict = {'like': 0, 'worse': 1, 'NB': 2}
    for i in range((cur_batch*bs), (cur_batch*bs)+bs):
        Y_temp.append(mydict[dataframe['class'][i]])
        data = torch.load('embeddings_r1/'+dataframe['name'][i]+'_'+dataframe['class'][i]+'.pt', map_location=device)
        X_temp.append(data[597:, :])  # Keep the last 194 rows of each embedding
    X = torch.stack(X_temp).to(device=device)
    Y = torch.tensor(Y_temp).to(device=device)
    return X, Y

parser = argparse.ArgumentParser()

parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-2)
args = parser.parse_args()
model_name = 'embed_model_bs_' + str(args.batch_size) + '_lr_' + str(args.lr)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net()
df_train = pd.read_csv('train_set_oversample.csv')
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
Y = df_train.loc[:, 'class']
mydict = {'like': 0, 'worse': 1, 'NB': 2}
Y_hot = np.array([mydict[a] for a in Y])
uni, freq = np.unique(Y_hot, return_counts=True)
print(freq)  # Sanity check: per-class counts in the oversampled training set
criterion = nn.CrossEntropyLoss()

model.to(device=device)
model.train()
for epoch in range(args.epochs):
    df_train_temp = df_train.sample(frac=1).values
    df_train_new = pd.DataFrame(df_train_temp, columns=['name', 'class'])  # Shuffle training set every epoch
    bctr = 0
    for i in range(len(df_train) // args.batch_size):  # Drops the final partial batch
        optimizer.zero_grad()
        cur_x, cur_y = load_data(df_train_new, bctr, args.batch_size)  # Load current batch
        bctr += 1
        cur_out = model(cur_x)
        loss = criterion(cur_out, cur_y)
        print(loss)  # Per-batch training loss
        loss.backward()
        optimizer.step()
    if (epoch+1) % 10 == 0:
        torch.save(model, 'saved_models/'+model_name+'_cur_epoch_'+str(epoch+1)+'.pt')
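
For completeness, my evaluation is essentially the following. This is a simplified sketch for the post: 'holdout_set.csv' is a stand-in for my real holdout CSV, and I load its embeddings with the same load_data helper.

# Simplified evaluation sketch; 'holdout_set.csv' is a stand-in file name
model.eval()
df_test = pd.read_csv('holdout_set.csv')
preds, labels = [], []
with torch.no_grad():
    for b in range(len(df_test) // args.batch_size):
        x, y = load_data(df_test, b, args.batch_size)
        out = model(x)
        preds.extend(out.argmax(dim=1).cpu().tolist())
        labels.extend(y.cpu().tolist())
print(balanced_accuracy_score(labels, preds))
print(confusion_matrix(labels, preds))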