Model Accuracy Is Almost Zero After Reloading

Hi there

I trained a model for Named Entity Recognition (NER) and reached roughly 80% training accuracy after 40 epochs, with about 78% validation accuracy. However, when I save the model and load it again later, the accuracy is almost zero. I thought this was a randomisation issue, so I seeded NumPy's random generator and PyTorch as well, but I still get the same behaviour.
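
This is roughly the seeding I added (a sketch; the exact cell in my notebook may differ slightly):

import random
import numpy as np
import torch

SEED = 42
random.seed(SEED)                 # Python's built-in RNG
np.random.seed(SEED)              # NumPy RNG
torch.manual_seed(SEED)           # PyTorch CPU RNG
torch.cuda.manual_seed_all(SEED)  # PyTorch GPU RNGs (no-op without CUDA)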

Here’s the model definition:

class NERModel(nn.Module):
  def __init__(self):
    super(NERModel,self).__init__()
    self.base_count = 128
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()
    self.BatchNorm = nn.BatchNorm1d(num_features=100,)#100 is just a placeholder which will be updated
    self.SpanEmbed = nn.Embedding(num_embeddings=len(char2idx),embedding_dim=EMBEDDING_DIM,_weight = torch.from_numpy(embed_matrix)).type(torch.float)#,_freeze = True
    self.SentEmbed = nn.Embedding(num_embeddings=len(char2idx),embedding_dim=EMBEDDING_DIM,_weight = torch.from_numpy(embed_matrix)).type(torch.float)#,_freeze = True

    self.Span_Layer_Stack_PreS = nn.Sequential(
        #nn.Embedding(num_embeddings=len(char2idx),embedding_dim=EMBEDDING_DIM,_weight = torch.from_numpy(embed_matrix),_freeze = True).type(torch.float),
        self.SpanEmbed,
        nn.LSTM(input_size=EMBEDDING_DIM,hidden_size=1800,bidirectional=True,batch_first=True).type(torch.float)
        )

    self.Span_Layer_Stack_PostS = nn.Sequential(
        # nn.BatchNorm1d(num_features=MAX_SPAN_LEN),#100 is just a placeholder which will be updated
        nn.Dropout(0.5),
        nn.Linear(in_features = 3600,out_features = 1920),
        nn.ReLU(),
        nn.Linear(in_features = 1920,out_features = 640),
        nn.ReLU(),
        nn.Linear(in_features = 640,out_features = 1280)

    )

    self.Sent_Layer_Stack_PreS = nn.Sequential(
        # nn.Embedding(num_embeddings=len(char2idx),embedding_dim=EMBEDDING_DIM,_weight = torch.from_numpy(embed_matrix),_freeze = True).type(torch.float),
        self.SentEmbed,
        nn.LSTM(input_size=EMBEDDING_DIM,hidden_size=1800,bidirectional=True,batch_first = True).type(torch.float)
        )

    self.Sent_Layer_Stack_PostS = nn.Sequential(
        nn.BatchNorm1d(num_features=3600),
        nn.Dropout(0.5),
        nn.Linear(in_features = 3600,out_features = 1920),
        nn.ReLU(),
        nn.Linear(in_features = 1920,out_features = 640),
        nn.ReLU(),
        nn.Linear(in_features = 640,out_features = 1280)

    )

    self.Comb_Layer_Stack = nn.Sequential(
        nn.Linear(in_features = 2560,out_features = 2048),
        nn.ReLU(),
        nn.Linear(in_features = 2048,out_features = 1024),
        nn.ReLU(),
        nn.Linear(in_features = 1024,out_features = 512),
        nn.ReLU(),
        nn.Linear(in_features = 512,out_features = 512),
        nn.ReLU(),
        nn.Linear(in_features = 512,out_features = 256),
        nn.ReLU(),
        nn.Linear(in_features = 256,out_features = 128),
        nn.ReLU(),
        nn.Linear(in_features = 128,out_features = len(label_idx)),
        nn.Sigmoid()
    )


  def forward(self,x:torch.Tensor) -> torch.Tensor:

    span_input,sent_input = torch.split(x,[int(x.shape[1]/2),int(x.shape[1]/2)],dim = 1)
    span_out_tot,(span_out,span_c_state_out) = self.Span_Layer_Stack_PreS(span_input.type(torch.long))
    sent_out_tot,(sent_out,sent_c_state_out) = self.Sent_Layer_Stack_PreS(sent_input.type(torch.long))

    #stack the layers
    span_out = torch.hstack([span_out[0],span_out[1]])
    sent_out = torch.hstack([sent_out[0],sent_out[1]])

    #run the stacked layers through the post stack layers
    span_out = self.Span_Layer_Stack_PostS(span_out.type(torch.float))
    sent_out = self.Sent_Layer_Stack_PostS(sent_out.type(torch.float))

    #combine the layers
    comb_layers = torch.hstack([span_out,sent_out])

    span_label = self.Comb_Layer_Stack(comb_layers)

    return span_label

The training (including checkpoint saving), validation, and test code is as follows:

def train_model(model:NERModel,num_epochs = 10,lrate = 0.005, save_freq = 2):

  torch.manual_seed(42)

  # m_accuracy = Accuracy(task="multiclass", num_classes=len(label_idx)).to(device)

  loss_fn = nn.BCELoss()
  opt = torch.optim.Adam(params=[param for param in model.parameters() if param.requires_grad == True],lr=lrate)

  epochs = num_epochs

  scheduler = lr_scheduler.ReduceLROnPlateau(opt,mode='min',factor=0.8,patience=2,verbose=True)

  mb = master_bar(range(epochs))

  for epoch in mb:#range(epochs):#tqdm(range(epochs),total = epochs):

    train_loss = 0
    train_acc = 0

    model.train()
    print(f'Epoch {epoch + 1}')
    total_data_processed = 0
    start_time = timer.perf_counter()

    for batch, (X,y) in progress_bar(enumerate(train_dataloader),total = len(train_dataloader),parent=mb):#enumerate(train_dataloader):#tqdm(enumerate(train_dataloader),total = len(train_dataloader)):

      y_pred = model(X).to(device)

      loss = loss_fn(y_pred.type(torch.float),y.type(torch.float))#torch.argmax(y_pred,dim=1).type(torch.float),y.type(torch.float))

      train_loss+=loss

      opt.zero_grad()

      loss.backward()

      opt.step()

      train_acc+= get_accuracy(torch.where(y_pred>0.5,1.0,0.0),y)#torch.argmax(y_pred,dim=1),y)

      total_data_processed += len(X)

    model.eval()

    with torch.inference_mode():
      total_val_loss, total_val_acc = 0, 0
      for X_val,y_val in val_dataloader:
        y_val_pred = model(X_val)
        total_val_loss += loss_fn(y_val_pred.type(torch.float),y_val.type(torch.float))
        total_val_acc += get_accuracy(torch.where(y_val_pred>0.5,1.0,0.0),y_val)
      val_loss = total_val_loss/len(val_dataloader)
      val_acc = total_val_acc/len(val_dataloader)
      scheduler.step(val_loss)
      print('Avg. Training Loss:{:.4f} | Val Loss:{:.4f} | Avg. Training Accuracy:{:.4f} | Val Accuracy:{:.4f}'\
      .format(train_loss/len(train_dataloader),val_loss,train_acc/len(train_dataloader),val_acc))

    end_time = timer.perf_counter()
    print('Epoch {:2d} | Elapsed Time:{:.3f}s'.format(epoch + 1, end_time - start_time))
    mb.write('Epoch {:2d} completed | Elapsed Time:{:.3f}s'.format(epoch + 1, end_time - start_time))

    if (epoch + 1) % save_freq == 0 and epoch + 1 > 0:
      filepath = config.environ_path[config.environ]['save'] + "Char_Tok_Models/Models/cp{:02d}.pth".format(epoch + 1)
      print('Saving model to:',filepath)
      torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'loss': val_loss,
              'accuracy': val_acc,
              }, filepath)

      print('Save completed')

def test_model(model_file:str,d_loader:DataLoader):

  torch.manual_seed(42)

  # m_accuracy = Accuracy(task="multiclass", num_classes=len(label_idx)).to(device)

  test_model = NERModel().to(device)

  loss_fn = nn.BCELoss()

  checkpoint = torch.load(config.environ_path[config.environ]['save'] + "Char_Tok_Models/Models/" + model_file)

  test_model.load_state_dict(checkpoint['model_state_dict'])

  test_model.eval()

  with torch.inference_mode():
    total_test_loss, total_test_acc = 0, 0
    for batch,(X_test,y_test) in progress_bar(enumerate(d_loader),total=len(d_loader)):
      y_test_pred = test_model(X_test)
      total_test_loss += loss_fn(y_test_pred.type(torch.float),y_test.type(torch.float))
      total_test_acc += get_accuracy(torch.where(y_test_pred>0.5,1.0,0.0),y_test)
    test_loss = total_test_loss/len(d_loader)
    test_acc = total_test_acc/len(d_loader)
    # scheduler.step(test_loss)
    print('Test Loss:{:.4f} | Test Accuracy:{:.4f}'.format(test_loss,test_acc))

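I call the test function like this (the checkpoint filename is just an example):

test_model('cp40.pth', test_dataloader)
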
The code for the accuracy metric is as follows:

def get_accuracy(y_pred,y_true):
  y_pred_arr = y_pred.cpu().numpy()
  y_true_arr = y_true.cpu().numpy()
  comp_out = np.where(np.sum(np.equal(y_pred_arr,y_true_arr),axis=1)<y_true_arr.shape[1],0,1)
  return np.sum(comp_out)/comp_out.shape[0]
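
To be clear, this is an exact-match metric: a sample only counts as correct if every label position matches. For example:

y_pred = torch.tensor([[1., 0., 0.],
                       [0., 1., 0.]])
y_true = torch.tensor([[1., 0., 0.],
                       [0., 0., 1.]])
print(get_accuracy(y_pred, y_true))  # 0.5 - only the first row matches at every position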

Please let me know what I’m doing wrong. Thanks.

Solved: the label codes were being generated afresh every time I restarted the kernel, each time with new numbers assigned to each label, so the indices the saved weights were trained against no longer matched the ones used at test time. I fixed it by hard-coding the labels. Is there a more elegant solution? Please let me know.
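
One idea for making this cleaner (a sketch with a hypothetical file name, not code I have run yet) is to persist the label_idx mapping to disk the first time it is built and reload it in later sessions, instead of hard-coding it or rebuilding it:

import json

LABELS_PATH = config.environ_path[config.environ]['save'] + "Char_Tok_Models/label_idx.json"

def save_label_idx(label_idx):
  # Write the label -> index mapping once, when it is first built.
  with open(LABELS_PATH, 'w') as f:
    json.dump(label_idx, f)

def load_label_idx():
  # Reuse the stored mapping in later sessions so the output indices stay stable.
  with open(LABELS_PATH) as f:
    return json.load(f)

The mapping could also be stored as an extra key in the checkpoint dict passed to torch.save, so it always travels with the weights.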