Hi there,
I’m trying to use a NN for a classification into two classes. As this did not work with my dataset (constant prediction for each batch) I wrote a simpler version of the code, but can’t still find the problem.
Here’s a minimal version code:
class Model(nn.Module):
def __init__(self, input_size, hidden_sizes_fc=[100, 2]):
super().__init__()
self.fc_list = nn.ModuleList([nn.Linear(input_size, hidden_sizes_fc[0])])
for hidden_size_fc_ind in range(0, len(hidden_sizes_fc)-1):
self.fc_list.append(nn.Linear(hidden_sizes_fc[hidden_size_fc_ind],
hidden_sizes_fc[hidden_size_fc_ind+1]))
def forward(self, x):
relu = nn.ReLU()
for i, FC in enumerate(self.fc_list):
x = FC(x)
x = relu(x)
return x
def train_std_nn(net, train, val, epochs, loss_fn):
optimiser = torch.optim.Adam(net.parameters(), lr=0.0001)
train_losses_epochs = []
val_score_epochs = []
net.train()
for epoch in trange(epochs):
train_loss = 0.0
total_computations = 0
for X, Y in train:
output = net(X)
loss = loss_fn(output, Y)
loss.backward()
optimiser.step()
train_loss += loss.item()
total_computations += Y.shape[0]
train_losses_epochs.append(train_loss / total_computations)
for X_val, Y_val in val:
output = net(X_val)
top_p, top_class = torch.topk(output, 1, dim=1)
pred = torch.flatten(top_class).detach().numpy()
val_score_epochs.append(roc_auc_score(Y_val.numpy(), pred))
return net, train_losses_epochs, val_score_epochs
epochs = 10
batch_size = 128
hidden_layers_size = [16, 2]
net = Model(input_size=11, hidden_sizes_fc=hidden_layers_size).double()
loss_fn = nn.CrossEntropyLoss()
aaa = torch.Tensor(np.random.rand(15, 11)).double()#.type(torch.LongTensor)
bbb = torch.Tensor(np.random.randint(0, 2, (15))).type(torch.LongTensor)
net, train_losses_epochs, val_score_epochs = train_std_nn(net, [[aaa, bbb]], [[aaa, bbb]], epochs, loss_fn)
I’ve plotted some graphs of the training loss and validation score (area under the curve). But the model doesn’t seem to learn anything… Training loss does random stuff (mainly decreasing but depends on the run) and auc is always 0.5
Thanks for help!