Overfitting issue in GNN on Colab

Hello everyone,

I have been trying to train a GNN with PyG for a multiclass graph-classification problem with 4 classes. The dataset is small (400 graphs) and imbalanced. The graphs represent biological networks and are PyG `Data` objects with the attributes `x`, `edge_index`, `edge_attr`, `edge_weight`, and `y`.
Each graph has approximately 900 nodes (with two features each) and 5000 edges.
The node features (`Data.x`) are of order 1e2–1e3, while `edge_attr` and `edge_weight` are of order 1e-2–1e-3.

I am running the code on Google Colab (GPU hardware).

After confirming that the model can overfit a single batch, I trained and tested it on the full dataset, but it overfits. I have tried several approaches, none of which has worked so far, so maybe there is something in the code that I am missing?

In particular, with dropout=0 the model easily overfits (100% training accuracy) but learns nothing that generalizes to the test set. With dropout=0.5 it usually reaches around 70% training accuracy, but test accuracy stays around 25%, which is chance level for 4 classes. With dropout I also tried enlarging the network and increasing the number of neurons in each layer. I trained it for 600 epochs or more, far more than the number of epochs it takes to overfit without dropout (fewer than 300).

I have also tried:
- changing the optimizer, the learning rate and the weight decay;
- adding a learning-rate scheduler;
- using different convolutional layers and architectures (unfortunately, nothing has been published on sufficiently similar problems);
- changing the batch size.

I also tried both with and without data augmentation.

Out of curiosity, I also tried merging the 4 classes into 2 (which is acceptable for this particular problem), but it did not help.

I appreciate that it may be difficult to help. Is there something I am missing, or something I should try (perhaps in a more organized, structured way)?

# install the required packages
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# make code reproducible
random_state = 1

# load dataset
from google.colab import files
uploaded = files.upload()
metabolic_dataset = torch.load("metabolic_graphs.pt")  # list of Data objects

# define transforms for data pre-processing and augmentation
from torch_geometric.transforms import BaseTransform

class GraphNormalize(BaseTransform):
    """Rescale node and/or edge features of a single graph by their maximum.

    Args:
        norm: ``"nodes"`` rescales ``data.x`` column-wise; ``"edges"`` rescales
            ``data.edge_attr`` column-wise, ``data.edge_weight`` by its global
            maximum and, when present, ``data.embeddings`` column-wise;
            ``"both"`` does both.

    The ``data`` object is modified in place and also returned.

    Note: the maximum is computed per graph, so after this transform absolute
    magnitudes are no longer comparable between graphs.

    Raises:
        ValueError: if ``norm`` is not one of the supported values.
    """

    def __init__(self, norm="nodes"):  # norm can be either "nodes", "edges" or "both"
        if norm not in ("nodes", "edges", "both"):
            raise ValueError(f"norm must be 'nodes', 'edges' or 'both', got {norm!r}")
        self.norm = norm

    @staticmethod
    def _divide_by_max(t, global_max=False):
        """Return ``t`` divided by its per-column (dim 0) or global maximum."""
        # For a 1-D tensor the dim-0 max is simply the overall max.
        m = t.max() if global_max else t.max(dim=0).values
        # Leave all-zero columns unchanged instead of producing NaN/inf.
        m = torch.where(m == 0, torch.ones_like(m), m)
        return t / m

    def __call__(self, data):
        if self.norm in ("nodes", "both"):
            data.x = self._divide_by_max(data.x)
        if self.norm in ("edges", "both"):
            data.edge_attr = self._divide_by_max(data.edge_attr)
            # torch.max(t) without a dim returns a 0-d tensor; the original
            # `[0]` indexing on it raised an IndexError.
            data.edge_weight = self._divide_by_max(data.edge_weight, global_max=True)
            # Not every graph carries an `embeddings` attribute.
            if getattr(data, "embeddings", None) is not None:
                data.embeddings = self._divide_by_max(data.embeddings)
        return data

from torch_geometric.data import Data
from torch_geometric.utils.mask import index_to_mask
import random


class EdgeSampler(BaseTransform):
    """Build a smaller graph from a random subset of the edges of ``data``.

    Args:
        num_edges: ``int`` -> number of edges to keep (capped at the number of
            edges in the graph); ``float`` -> fraction of the edges to keep
            (e.g. ``0.9`` keeps 90 % of them).

    Only nodes incident to at least one sampled edge are kept, and they are
    relabelled to ``0..k-1`` in ascending order of their old ids. The returned
    ``Data`` carries only ``x``, ``edge_index``, ``edge_weight``, ``edge_attr``
    and ``y``. Sampling uses Python's ``random`` module, so seed it for
    reproducible augmentations.

    NOTE(review): each column of ``edge_index`` is sampled independently, so if
    the graphs store both directions of an undirected edge, the two directions
    may be kept or dropped separately.

    Raises:
        TypeError: if ``num_edges`` is neither an int nor a float.
    """

    def __init__(self, num_edges):  # num_edges can be either int or float
        self.num_edges = num_edges

    def __call__(self, data):
        total = data.edge_index.size(1)
        if isinstance(self.num_edges, int):
            k = self.num_edges
        elif isinstance(self.num_edges, float):
            # Fraction of the edges (the original computed this and then
            # unconditionally raised NotImplementedError).
            k = int(self.num_edges * total)
        else:
            raise TypeError(
                f"num_edges must be int or float, got {type(self.num_edges).__name__}"
            )
        # random.sample raises ValueError when asked for more items than exist.
        sampled_links = random.sample(range(total), min(k, total))

        edge_index = data.edge_index[:, sampled_links]
        edge_weight = data.edge_weight[sampled_links]
        edge_attr = data.edge_attr[sampled_links]

        # Sorted ids of the nodes that still touch at least one edge.
        nodes = torch.unique(edge_index)
        x = data.x[nodes]

        # Lookup table old id -> new contiguous id, sized by the original node
        # count so it also works when no edge is sampled.
        node_idx = torch.zeros(data.x.size(0), dtype=torch.long, device=edge_index.device)
        node_idx[nodes] = torch.arange(nodes.numel(), device=edge_index.device)
        edge_index = node_idx[edge_index]

        return Data(x=x, edge_index=edge_index, edge_weight=edge_weight,
                    edge_attr=edge_attr, y=data.y)

# pre-processing and data augmentation
import torch_geometric.transforms as T

# NOTE(review): the two transform lists below are cut off in this snippet (the
# `T.Compose([` calls are never closed), so this cell does not parse as shown.
# Presumably transforms1 is the pre-processing chain (e.g. GraphNormalize) and
# transforms2 adds EdgeSampler for augmentation - confirm against the notebook.
transforms1 = T.Compose([

transforms2 = T.Compose([

# Each transform chain is applied to a freshly loaded copy of the dataset,
# because transforms may modify the Data objects they receive in place
# (GraphNormalize, for example, reassigns attributes on `data` and returns the
# same object). copy.deepcopy(metabolic_dataset) would avoid re-reading the file.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
# which presumably rejects pickled Data objects; pass weights_only=False for
# this trusted file if loading fails.
metabolic_dataset = torch.load("metabolic_graphs.pt") # I have to load it again to apply a new transform chain
dataset = [transforms1(data) for data in metabolic_dataset]

# as data augmentation technique, I sample the edges of the original graphs to obtain new slightly smaller graphs
# Two independently sampled augmented copies are built, one per reload. The
# cross-validation loop adds only the training-fold entries of each copy to
# the training set, so augmented versions of test graphs are never trained on.
metabolic_dataset2 = torch.load("metabolic_graphs.pt")
augmentations1 = [transforms2(data) for data in metabolic_dataset2]
metabolic_dataset3 = torch.load("metabolic_graphs.pt")
augmentations2 = [transforms2(data) for data in metabolic_dataset3]

# model definition
from torch.nn import Linear, ReLU, ModuleList
import torch.nn.functional as F
from torch_geometric.nn import DeepGCNLayer, GraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.norm import GraphNorm

class DeeperGCN(torch.nn.Module):
    """GraphConv + GraphNorm stack, mean-pooled into a per-graph linear classifier.

    Args:
        hidden_channels: width of every hidden layer.
        num_layers: number of hidden conv blocks after the input conv block.
        in_channels: number of node features (2 in this dataset).
        num_classes: number of output logits. NOTE(review): defaults to 2 to
            match the original code, but a 4-class loss (4 class weights)
            requires ``num_classes=4``.
        dropout: dropout probability after each hidden block and before the
            classifier.
    """

    def __init__(self, hidden_channels, num_layers, in_channels=2, num_classes=2, dropout=0.5):
        # Must run before any submodule is assigned; otherwise nn.Module raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.dropout = dropout

        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.norm1 = GraphNorm(hidden_channels)

        self.layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GraphConv(hidden_channels, hidden_channels)
            norm = GraphNorm(hidden_channels)
            act = ReLU(inplace=True)
            # The original built each layer but never appended it, so the
            # model silently had no hidden layers at all.
            self.layers.append(DeepGCNLayer(conv, norm, act, block="plain", dropout=dropout))

        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, edge_weight, batch=None):
        """Return one row of class logits per graph in ``batch``."""
        # GraphNorm needs `batch` to compute statistics per graph. Without it,
        # all graphs in the mini-batch are normalised together, so a graph's
        # embedding depends on which other graphs share its batch.
        x = self.conv1(x, edge_index, edge_weight)
        x = self.norm1(x, batch)
        x = x.relu()
        for layer in self.layers:
            # Same computation as DeepGCNLayer's "plain" block
            # (conv -> norm -> act -> dropout), unrolled here because
            # DeepGCNLayer.forward calls norm(h) without the batch vector.
            h = layer.conv(x, edge_index, edge_weight)
            h = layer.norm(h, batch)
            h = layer.act(h)
            x = F.dropout(h, p=layer.dropout, training=self.training)

        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

# training and testing
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Class-weighted loss to counter the class imbalance. The weights look like
# 140 / (samples per class), i.e. the largest class gets weight 1 - presumably
# the class counts are 114, 140, 47 and 98; verify against the dataset.
# Calling .to(device) on the loss module moves its registered `weight` buffer.
criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor([140 / 114, 1, 140 / 47, 140 / 98]),
).to(device)  # to balance the training
cv = StratifiedKFold(n_splits=5, random_state=random_state, shuffle=True)

def train():
    """Run one optimisation epoch over ``train_loader`` and record its loss.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``train_loss`` created in the cross-validation loop,
    and appends the per-graph average training loss to ``train_loss``.
    """
    model.train()  # enable dropout
    model_device = next(model.parameters()).device
    running_loss = 0.0
    for data in train_loader:
        data = data.to(model_device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_weight, batch=data.batch)
        loss = criterion(out, data.y)
        # The original computed the loss but never back-propagated or
        # stepped, so the weights were never updated.
        loss.backward()
        optimizer.step()
        # criterion returns a (class-weighted) batch mean. Scaling by the batch
        # size makes the division below an approximate per-graph average; the
        # original `=+` simply overwrote the running total.
        running_loss += loss.item() * data.num_graphs
    train_loss.append(running_loss / len(train_loader.dataset))

def test(loader, val_accuracy, val_loss):
    """Evaluate the module-level ``model`` on ``loader``.

    Args:
        loader: DataLoader of graphs to evaluate.
        val_accuracy: list the accuracy is appended to.
        val_loss: list the approximate per-graph average loss is appended to.

    Returns:
        Fraction of graphs in ``loader.dataset`` classified correctly.
    """
    model.eval()  # disable dropout for evaluation
    model_device = next(model.parameters()).device
    correct = 0
    running_loss = 0.0
    with torch.no_grad():  # no autograd graph needed for evaluation
        for data in loader:
            data = data.to(model_device)
            out = model(data.x, data.edge_index, data.edge_weight, batch=data.batch)
            loss = criterion(out, data.y)
            # Accumulate (the original `=+` overwrote the total each batch).
            running_loss += loss.item() * data.num_graphs
            pred = out.argmax(dim=1)  # Use the class with highest probability.
            correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    accuracy = correct / len(loader.dataset)  # Derive ratio of correct predictions.
    val_accuracy.append(accuracy)
    val_loss.append(running_loss / len(loader.dataset))
    return accuracy

# One class label per graph, used by StratifiedKFold to stratify the folds.
# Assumes each graph's `data.y` holds a single label (as the per-graph
# CrossEntropyLoss in train/test requires). The original loop had no body.
y = torch.tensor([int(data.y) for data in dataset])

# 5-fold stratified cross-validation: a fresh model is trained on each fold.
for k, (train_index, test_index) in enumerate(cv.split(dataset, y.cpu()), start=1):
    train_data = [dataset[i] for i in train_index]
    test_data = [dataset[i] for i in test_index]
    # Augmented copies are taken only for training-fold graphs, so no version
    # of a test graph is seen during training.
    augmented_data = (
        train_data
        + [augmentations1[i] for i in train_index]
        + [augmentations2[i] for i in train_index]
    )
    # shuffle=True: otherwise every epoch sees identical batches, with the
    # original graphs first and each augmented copy grouped after them.
    train_loader = DataLoader(augmented_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32)
    train_loss = []
    val_loss = []
    train_accuracy = []
    val_accuracy = []

    # NOTE(review): DeeperGCN's classifier outputs 2 logits as written, while
    # `criterion` carries 4 class weights - these must agree.
    model = DeeperGCN(hidden_channels=512, num_layers=8).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-4)

    print("Split {}:".format(k))
    for epoch in range(1, 301):
        # The original loop only evaluated, so the model was never trained.
        train()
        train_acc = test(train_loader, train_accuracy, [])
        test_acc = test(test_loader, val_accuracy, val_loss)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\n')
        if np.isnan(train_loss).any():
            print("NaN in training loss, stopping this split.")
            break

Many thanks