Transfer learning implementation problem

I am new at coding (trying to learn some for my masters). I need to create a neural network with pytorch for an assigment that I will pretrain on one dataset (I have 3) and finetune/test on another. Ι wrote some code which seemed to worked. I pretrained on dataset number 1 and finetuned on dataset number 2 and got decent results (83% balanced accuracy) but when I tried to change the combination of datasets (e.g. train on dataset 2 and finetune/test on dataset 3) the perfomance was at chance. I am not sure if something is wrong with my code or is this cause by the dataset.

I tried different hyperparameters thinking that might be the problem but nothing. The performanced stayed at chance in all combinations other than preatrain on dataset 1 and finetune/test on dataset 2.

I would really appreciate if you could take a quick look at my code and share your opinion. Is there something wrong with my code? I really can’t figure it out.

class MyNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, lr=0.001, fn_lr = 0.001):
        super(MyNetwork, self).__init__()
        self.lr = lr
        self.fn_lr = fn_lr
        # TO DO:TRY BATCH NORM AND OTHER RELU VARIANTS
        # TO DO: TRY DROPOUT RATES
        self.linear_relu_stack_main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(negative_slope=0.02),

            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(negative_slope=0.03),
            #nn.Dropout(0.3),

            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(negative_slope=0.03),
            #nn.Dropout(0.3),
        )
        self.linear_relu_stack_output = nn.Sequential(
        
            nn.Linear(hidden_size, output_size)
             
        )

    def forward(self, x):
        vec = self.linear_relu_stack_main(x)
        logits = self.linear_relu_stack_output(vec)
        
        return logits

    def optimize(self, train_dataloader, val_dataloader = None, 
                 threshold=0.5, epochs=10, pretrain=True):
        
        
        loss_function = nn.BCEWithLogitsLoss()

        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        
        if pretrain:
            print("\n=== Pretraining ===\n")
        
            for epoch in range(epochs):
                mean_loss_per_epoch = 0
                
                self.train()
                for features, labels in train_dataloader:
                    optimizer.zero_grad()
                    predictions = self(features)
                    batch_loss = loss_function(predictions, labels)
                    batch_loss.backward()
                    optimizer.step()

                    mean_loss_per_epoch += batch_loss.item()

                mean_train_loss = mean_loss_per_epoch / len(train_dataloader)

                print(f"Epoch {epoch + 1}/{epochs}")
                print(f"Mean Pretraining Loss: {mean_train_loss}")
                
                
                
                self.eval()
                
                with torch.no_grad():
                    for features, labels in val_dataloader:
                        val_predictions = self(features)
                        scores = self.calculate_metrics(val_predictions, labels, threshold=0.5)
                        print(f"Balanced Accuracy:{scores['balanced_accuracy']}")
                        
        
        else:
            print("\n=== Finetuning ===\n")
            
            self.linear_relu_stack_main.requires_grad_(False)
    
                    
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                self.parameters()), 
                                         lr=self.fn_lr)
            
            
            
            for name, parameter in self.named_parameters():  
                if parameter.requires_grad: 
                    print(name)

                

            for epoch in range(epochs):
                self.train()
                mean_loss_per_epoch = 0

                for features, labels in train_dataloader:
                    optimizer.zero_grad()
                    predictions = self(features)
                    batch_loss = loss_function(predictions, labels)
                    batch_loss.backward()
                    optimizer.step()

                    mean_loss_per_epoch += batch_loss.item()

                mean_train_loss = mean_loss_per_epoch / len(train_dataloader)

                print(f"Epoch {epoch + 1}/{epochs}")
                print(f"Mean Finetuning Loss: {mean_train_loss}")

    def test(self, test_dataloader, threshold):
        self.eval()
        predictions = []
        labels = []

        with torch.no_grad():
            for test_X, test_y in test_dataloader:
                test_pred = self(test_X)
                test_pred = torch.sigmoid(test_pred)
                
                predictions.extend(test_pred.numpy())
                labels.extend(test_y.numpy())

        predictions = torch.tensor(predictions).squeeze()
        labels = torch.tensor(labels).squeeze()

        metrics = self.calculate_metrics(predictions, labels, threshold)
        for metric, score in metrics.items():
            
            print(f"{metric}: {round(score, 3)}")
       
       
  
    def calculate_metrics(self, predictions, labels, threshold):
        
        predicted_classes = (torch.sigmoid(predictions) > threshold).numpy()
        labels_np = labels.numpy()

        metrics = {
            'accuracy': accuracy_score(labels_np, predicted_classes),
            'precision': precision_score(labels_np, predicted_classes),
            'recall': recall_score(labels_np, predicted_classes),
            'f1': f1_score(labels_np, predicted_classes),
            'balanced_accuracy': balanced_accuracy_score(labels_np, predicted_classes),
            'mcc': matthews_corrcoef(labels_np, predicted_classes),
        }

        return metrics

def main():
    torch.manual_seed(42)
    it_df = pd.read_csv('..')
    cz_df = pd.read_csv('...')
    sp_df = pd.read_csv('...')
    
    
    #nn parameters
    thresh= 0.5
    hidden= 32
    tr_epochs = 20
    fn_epochs= 5
    tr_batch_size= 32
    fn_batch_size= 32
    learning_rate= 0.01
    fineting_lr= 0.001

     #datasets
    pretrain_df =  it_df.copy()
    fine_tuning_df = sp_df.copy()
    
    
    
    pretrain_df = drop_empty(pretrain_df)
    fine_tuning_df = drop_empty(fine_tuning_df)
    fine_tuning_df = fine_tuning_df[pretrain_df.columns.tolist()]
    
    
    
    pretrain_features, pretrain_labels = define_features_labels(pretrain_df, label_column='status')
    
    x_pretrain, x_val, y_pretrain, y_val = train_test_split(
        pretrain_features, pretrain_labels, test_size=0.2, random_state=42, stratify=pretrain_labels
    )
    

    finetune_features, finetune_labels = define_features_labels(fine_tuning_df, label_column='status')

    x_finetune, x_test, y_finetune, y_test = train_test_split(
        finetune_features, finetune_labels, test_size=0.2, random_state=42, stratify=finetune_labels
    )
    
  
   
    pretrain_dataset = CustomDataset(x_pretrain, y_pretrain)
    pretrain_loader = DataLoader(pretrain_dataset, batch_size=tr_batch_size, shuffle=True)
    
    val_dataset = CustomDataset(x_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=tr_batch_size, shuffle=True)
    
   
  
    input_size = x_pretrain.shape[1]
    hidden_size = hidden
    output_size = 1
    model = MyNetwork(input_size, hidden_size, output_size, lr=learning_rate, fn_lr= fineting_lr)
    
    
    model.optimize(pretrain_loader, val_loader,  pretrain= True, epochs=tr_epochs)

    
    
    finetune_dataset = CustomDataset(x_finetune, y_finetune)
    test_dataset = CustomDataset(x_test, y_test)
    finetune_loader = DataLoader(finetune_dataset, batch_size=fn_batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
   

    print("Fine-tuning the model...")
    model.optimize(finetune_loader, pretrain = False, epochs=fn_epochs )  
    model.test(test_loader, threshold = thresh)


if __name__ == '__main__':
    main()