Hello,
I have trained a number of models with different hyper-parameters. I save each model's state dict every 10 epochs and load the checkpoints in a separate script for inference. However, every saved model outputs the exact same probability distribution for all samples. To clarify: there are 3 possible classes, and the probability assigned to class 0 for sample 1 is exactly the same as the probability assigned to class 0 for every other sample. I have confirmed that the data is loaded correctly and, by plotting the loss curves, that the models trained correctly. I am completely puzzled and could really use help diagnosing this. I have included my inference code below. Any and all help would be greatly appreciated. Thanks!
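In case it helps narrow things down, here is a rough sketch of the sanity check I have in mind (it assumes the same Net and load_data used in my script below, and "model_a" is just a placeholder checkpoint name): if two different batches yield identical inputs, the data pipeline is at fault; if the inputs differ but the logits are identical, the model itself is producing constant outputs.

import torch
import pandas as pd
from LA_model import Net, load_data

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
df_test = pd.read_csv('test_set.csv')

model = Net()
model.load_state_dict(torch.load('saved_models/model_a.pt',  # placeholder checkpoint name
                                 map_location=device, weights_only=True))
model.to(device)
model.eval()

with torch.no_grad():
    x0, _ = load_data(df_test, 0, 100, device)  # batch 0
    x1, _ = load_data(df_test, 1, 100, device)  # batch 1
    # Identical inputs across batches would point to a bug in load_data.
    print('batches identical:', torch.equal(x0, x1))
    # Differing inputs but identical logits would point to the model.
    print('logits identical:', torch.equal(model(x0), model(x1)))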
import torch
import pandas as pd
import argparse
from sklearn.metrics import balanced_accuracy_score
from LA_model import Net, load_data

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str)
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

df_train = pd.read_csv('train_set.csv')
df_test = pd.read_csv('test_set.csv')

# Load the checkpoint into a fresh model; map_location keeps this working
# on CPU-only machines, and eval() disables dropout/batch-norm updates.
model = Net()
model.load_state_dict(torch.load('saved_models/' + args.model_name + '.pt',
                                 map_location=device, weights_only=True))
model.to(device=device)
model.eval()
with torch.no_grad():
    train_pred, train_label = [], []
    test_pred, test_label = [], []

    # Training set: iterate over full batches of 100 samples.
    bctr = 0
    for i in range(len(df_train) // 100):
        cur_x, y = load_data(df_train, bctr, 100, device)
        bctr += 1
        logits = model(cur_x)   # run the forward pass once per batch
        print(logits)           # inspect the raw outputs
        y_hat = torch.argmax(logits, 1)
        for a, b in zip(y, y_hat):
            train_pred.append(int(b))
            train_label.append(int(a))

    # Test set: same procedure.
    bctr = 0
    for i in range(len(df_test) // 100):
        cur_x, y = load_data(df_test, bctr, 100, device)
        bctr += 1
        y_hat = torch.argmax(model(cur_x), 1)
        for a, b in zip(y, y_hat):
            test_pred.append(int(b))
            test_label.append(int(a))
train_acc = balanced_accuracy_score(train_label, train_pred)
test_acc = balanced_accuracy_score(test_label, test_pred)
print('train balanced accuracy:', train_acc)
print('test balanced accuracy:', test_acc)
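For reference, I invoke the script with --model_name set to the checkpoint's base name, so saved_models/<model_name>.pt must exist, e.g. python inference.py --model_name model_a (inference.py stands for this script and model_a is a placeholder).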