The code I am using to train my model is below; I'd appreciate any help that could keep my runtime from crashing and let the model train successfully.
import os
import pickle
import pandas as pd
from sklearn.model_selection import KFold
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
import torch.nn as nn
from sklearn.metrics import precision_score, recall_score
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
interactions_df = pd.read_csv('/content/drive/MyDrive/PPI/Known_Interactions_Data.csv')
graph_dir = "/content/drive/MyDrive/PPI/Protein_Graphs"
hyperparameters = {
    'dropout': [0.7, 0.5, 0.2],
    'num_layers': [1, 2],
    'classifier_hidden_dim': [64, 128],
    'input_dim': [64],
    'num_heads': [1],
    'learning_rate': [0.001, 0.01, 0.1],
    'num_epochs': [10, 20, 30],
}
best_precision = 0.0
best_recall = 0.0
best_hyperparameters = None
for dropout in hyperparameters['dropout']:
    for num_layers in hyperparameters['num_layers']:
        for classifier_hidden_dim in hyperparameters['classifier_hidden_dim']:
            for input_dim in hyperparameters['input_dim']:
                for num_heads in hyperparameters['num_heads']:
                    for learning_rate in hyperparameters['learning_rate']:
                        for num_epochs in hyperparameters['num_epochs']:
                            # one model/optimizer per hyperparameter combination,
                            # shared across all CV folds below
                            model = GCN_Model(num_features_pro=4,
                                              classifier_hidden_dim=classifier_hidden_dim,
                                              input_dim=input_dim,
                                              num_layers=num_layers,
                                              dropout=dropout)
                            model.to(device)
                            optimizer = Adam(model.parameters(), lr=learning_rate)
                            criterion = nn.BCEWithLogitsLoss()
                            kfold = KFold(n_splits=3, shuffle=True)
                            precision_scores = []
                            recall_scores = []
                            for fold, (train_indices, val_indices) in enumerate(kfold.split(interactions_df)):
                                train_data = interactions_df.iloc[train_indices]
                                val_data = interactions_df.iloc[val_indices]
                                # Training loop
                                model.train()
                                for epoch in range(num_epochs):
                                    running_loss = 0.0
                                    batch_size = 10
                                    num_batches = len(train_indices) // batch_size
                                    for batch_idx in range(num_batches):
                                        start_idx = batch_idx * batch_size
                                        end_idx = (batch_idx + 1) * batch_size
                                        batch_data = train_data.iloc[start_idx:end_idx]
                                        batch_loss = 0.0
                                        for idx, row in batch_data.iterrows():
                                            # columns: [1] = protein A ID, [2] = protein B ID, [3] = interaction label
                                            protein_A = row.iloc[1]
                                            protein_B = row.iloc[2]
                                            interaction_class = 1 if row.iloc[3] == 'Yes' else 0
                                            filenameA = os.path.join(graph_dir, f'{protein_A}_pg.gpickle')
                                            filenameB = os.path.join(graph_dir, f'{protein_B}_pg.gpickle')
                                            with open(filenameA, 'rb') as f:
                                                graph_A = pickle.load(f)
                                            with open(filenameB, 'rb') as f:
                                                graph_B = pickle.load(f)
                                            x_A, edge_index_A, edge_features_A = extract_graph_features(graph_A)
                                            x_B, edge_index_B, edge_features_B = extract_graph_features(graph_B)
                                            x_A = x_A.to(device)
                                            edge_index_A = edge_index_A.to(device)
                                            edge_features_A = edge_features_A.to(device)
                                            x_B = x_B.to(device)
                                            edge_index_B = edge_index_B.to(device)
                                            edge_features_B = edge_features_B.to(device)
                                            graphA = (x_A, edge_index_A, edge_features_A)
                                            graphB = (x_B, edge_index_B, edge_features_B)
                                            labels = torch.tensor([interaction_class]).float().to(device)
                                            optimizer.zero_grad()
                                            outputs = model(graphA, graphB)
                                            outputs = outputs.squeeze()
                                            labels = labels.squeeze()
                                            loss = criterion(outputs, labels)
                                            # note: the optimizer steps once per sample, so batch_size
                                            # here only affects how the loss is averaged for reporting
                                            loss.backward()
                                            optimizer.step()
                                            batch_loss += loss.item()
                                        running_loss += batch_loss / batch_size
                                    epoch_loss = running_loss / num_batches
                                    print(f'Fold {fold + 1}, Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
                                # Validation loop
                                model.eval()
                                val_loss = 0.0
                                correct = 0
                                total = 0
                                predicted_labels = []
                                true_labels = []
                                with torch.no_grad():
                                    # iterate positionally over val_data (val_indices index into
                                    # interactions_df, not into the smaller val_data frame)
                                    for i in range(len(val_data)):
                                        row = val_data.iloc[i]
                                        # same column positions as in the training loop
                                        protein_A = row.iloc[1]
                                        protein_B = row.iloc[2]
                                        interaction_class = 1 if row.iloc[3] == 'Yes' else 0
                                        filenameA = os.path.join(graph_dir, f'{protein_A}_pg.gpickle')
                                        filenameB = os.path.join(graph_dir, f'{protein_B}_pg.gpickle')
                                        with open(filenameA, 'rb') as f:
                                            graph_A = pickle.load(f)
                                        with open(filenameB, 'rb') as f:
                                            graph_B = pickle.load(f)
                                        x_A, edge_index_A, edge_features_A = extract_graph_features(graph_A)
                                        x_B, edge_index_B, edge_features_B = extract_graph_features(graph_B)
                                        x_A = x_A.to(device)
                                        edge_index_A = edge_index_A.to(device)
                                        edge_features_A = edge_features_A.to(device)
                                        x_B = x_B.to(device)
                                        edge_index_B = edge_index_B.to(device)
                                        edge_features_B = edge_features_B.to(device)
                                        graphA = (x_A, edge_index_A, edge_features_A)
                                        graphB = (x_B, edge_index_B, edge_features_B)
                                        labels = torch.tensor([interaction_class]).float().to(device)
                                        outputs = model(graphA, graphB)
                                        outputs = outputs.squeeze()
                                        labels = labels.squeeze()
                                        loss = criterion(outputs, labels)
                                        val_loss += loss.item()
                                        # binary prediction from the single logit (threshold at 0.5);
                                        # torch.max(outputs, 1) fails on a 0-d output
                                        predicted = (torch.sigmoid(outputs) > 0.5).float()
                                        total += labels.numel()
                                        correct += (predicted == labels).sum().item()
                                        predicted_labels.append(predicted.item())
                                        true_labels.append(labels.item())
                                val_loss /= len(val_indices)
                                accuracy = correct / total
                                print(f'Fold {fold + 1}, Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}')
                                precision = precision_score(true_labels, predicted_labels)
                                recall = recall_score(true_labels, predicted_labels)
                                precision_scores.append(precision)
                                recall_scores.append(recall)
                            avg_precision = sum(precision_scores) / len(precision_scores)
                            avg_recall = sum(recall_scores) / len(recall_scores)
                            if avg_precision > best_precision and avg_recall > best_recall:
                                best_precision = avg_precision
                                best_recall = avg_recall
                                best_hyperparameters = {
                                    'dropout': dropout,
                                    'num_layers': num_layers,
                                    'classifier_hidden_dim': classifier_hidden_dim,
                                    'input_dim': input_dim,
                                    'num_heads': num_heads,
                                    'learning_rate': learning_rate,
                                    'num_epochs': num_epochs}
print(f"Best Precision: {best_precision}, Best Recall: {best_recall}")
print("Best Hyperparameters:", best_hyperparameters)
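One thing I keep coming back to in this loop is that it re-reads and unpickles both protein graphs from Drive for every row in every epoch. A per-protein cache of the extracted features is something I have been considering to cut that down, roughly like the sketch below (feature_cache and cached_graph_features are not part of my current code, and this assumes the extracted tensors are small enough to keep in RAM, which I am not sure about):

import os
import pickle

# Hypothetical cache: protein ID -> (x, edge_index, edge_features) tensors on CPU.
feature_cache = {}

def cached_graph_features(protein_id):
    # Load and convert each protein graph only once per run instead of
    # unpickling it again for every row in every epoch.
    if protein_id not in feature_cache:
        path = os.path.join(graph_dir, f'{protein_id}_pg.gpickle')
        with open(path, 'rb') as f:
            graph = pickle.load(f)
        feature_cache[protein_id] = extract_graph_features(graph)
    return feature_cache[protein_id]

# Inside the training/validation loops this would replace the open/pickle.load/
# extract_graph_features calls, e.g.:
# x_A, edge_index_A, edge_features_A = cached_graph_features(protein_A)

I am not sure whether this would actually help the memory situation or make it worse.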
I also tried tools like RandomizedSearchCV to perform the hyperparameter search, but because every data instance has a different shape, I kept running into errors. I tried a DataLoader as well and again ran into a shape mismatch (a rough sketch of the kind of setup I mean follows this paragraph). I finally decided to iterate through my dataset manually, and that is what is now crashing my runtime. I tried Google Colab's GPU too, but the runtime crashed again with the message "used up all the available RAM".
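For context, this is roughly the kind of Dataset / collate_fn setup I mean for the variable-sized graphs (ProteinPairDataset and list_collate are placeholder names, and it reuses the hypothetical cached_graph_features helper sketched above, so none of this is code that currently works for me):

from torch.utils.data import Dataset, DataLoader

class ProteinPairDataset(Dataset):
    # Placeholder dataset: one (graph_A, graph_B, label) triple per CSV row.
    def __init__(self, dataframe):
        self.df = dataframe.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        protein_A, protein_B = row.iloc[1], row.iloc[2]
        label = 1.0 if row.iloc[3] == 'Yes' else 0.0
        # cached_graph_features is the hypothetical helper sketched above
        return cached_graph_features(protein_A), cached_graph_features(protein_B), label

def list_collate(batch):
    # The default collate_fn tries to stack tensors of identical shape, which
    # fails for graphs of different sizes; returning the samples as a plain
    # list sidesteps that and leaves per-sample handling to the training loop.
    return list(batch)

loader = DataLoader(ProteinPairDataset(interactions_df), batch_size=10,
                    shuffle=True, collate_fn=list_collate)

The point of returning a plain list from the collate function is only to stop the default collation from stacking differently-sized tensors; the model would still be called sample by sample inside the loop.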
Any help is highly appreciated.