Hello everyone,

I have been trying to train a GNN using PyG for a multiclass classification problem with 4 classes. The dataset is small (400 samples) and imbalanced. The graphs represent biological networks and are instances of the class Data, with attributes x, edge_index, edge_attr, edge_weight, and y.

Each graph has approx. 900 nodes with two features each, and 5000 edges.

The order of magnitude of Data.x is approx. 1e2-1e3, while for edge_attr and edge_weight it is approx. 1e-2-1e-3.

I am running the code on Google Colab (GPU hardware).

After managing to overfit the model on a single batch, I trained and tested it but it overfitted. I tried several approaches, none of which has worked so far, so maybe there is something in the code that I am missing?

In particular, when using dropout=0 the model easily overfits (100% accuracy) but obviously learns nothing on the test set, while when using dropout=0.5 the model usually reaches around 70% accuracy but the performance on the test set is still around 25%. When applying dropout I also tried expanding the network and increasing the number of neurons in each layer, and I trained it for 600 epochs or more, which is a lot compared to the number of epochs it takes to overfit with no dropout (i.e. less than 300).

I have also tried changing the optimizer, the lr, the weight decay, adding a scheduler for the lr, using different convolutional layers and architectures (unfortunately there is nothing published on significantly similar problems), and changing the batchsize. I tried with and without data augmentation.

I tried to regroup the 4 classes into 2 (which is ok for this particular problem) out of curiosity as well, but achieved nothing.

I appreciate that it may be difficult to help, but maybe there is something that I am missing or that I should try (maybe in a more organized/structured fashion)?

```
# install the required packages
import torch
# select the Colab GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `!pip` lines are Colab/IPython shell magics, not Python.
# ${TORCH} must be exported beforehand (e.g. from torch.__version__) or the
# wheel index URL below is incomplete — TODO confirm it is set in this notebook.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# make code reproducible
# NOTE(review): this seeds only torch's RNG; python's `random` is seeded
# further below, but numpy's RNG (if any transform uses it) is not.
random_state = 1
torch.manual_seed(random_state)
# load dataset
# interactive Colab file-upload widget; blocks until the user picks a file
from google.colab import files
uploaded = files.upload()
metabolic_dataset = torch.load("metabolic_graphs.pt") # list of Data objects
# define transforms for data pre-processing and augmentation
from torch_geometric.transforms import BaseTransform
class GraphNormalize(BaseTransform):
    """Scale graph features into [0, 1] by dividing by their maximum.

    Node features are divided column-wise by the per-feature maximum;
    edge attributes likewise; the (1-D) edge weights by their global max.

    Args:
        norm: which tensors to normalize — "nodes", "edges" or "both".
    """

    def __init__(self, norm: str = "nodes"):
        # fail fast on a typo instead of silently falling into a branch
        if norm not in ("nodes", "edges", "both"):
            raise ValueError(
                "norm must be 'nodes', 'edges' or 'both', got {!r}".format(norm))
        self.norm = norm

    def __call__(self, data):
        if self.norm in ("nodes", "both"):
            # BUG FIX: the original "both" branch normalized `data.embeddings`,
            # an attribute that does not exist on these Data objects
            # (schema: x, edge_index, edge_attr, edge_weight, y) — use x.
            # torch.max(t, 0) returns (values, indices); [0] keeps the values.
            data.x = data.x / torch.max(data.x, 0)[0]
        if self.norm in ("edges", "both"):
            data.edge_attr = data.edge_attr / torch.max(data.edge_attr, 0)[0]
            # BUG FIX: torch.max(t) with no dim returns a 0-dim tensor, so the
            # original `torch.max(data.edge_weight)[0]` raised IndexError
            # whenever this branch ran; use .max() directly.
            data.edge_weight = data.edge_weight / data.edge_weight.max()
        return data
from torch_geometric.data import Data
from torch_geometric.utils.mask import index_to_mask
import random
random.seed(random_state)
class EdgeSampler(BaseTransform):
    """Data-augmentation transform: keep a random subset of edges and
    return a new graph restricted to the nodes those edges touch.

    Args:
        num_edges: absolute number of edges to keep (int) or the fraction
            of the original edges to keep (float).
    """

    def __init__(self, num_edges):
        self.num_edges = num_edges

    def __call__(self, data):
        total_edges = data.edge_index.shape[1]
        # resolve how many edges to draw
        if isinstance(self.num_edges, int):
            n_keep = self.num_edges
        elif isinstance(self.num_edges, float):
            n_keep = int(self.num_edges * total_edges)
        else:
            raise NotImplementedError
        kept = random.sample(range(total_edges), n_keep)
        # restrict the edge tensors to the sampled edges
        sub_index = data.edge_index[:, kept]
        sub_attr = data.edge_attr[kept]
        sub_weight = data.edge_weight[kept]
        # surviving nodes are exactly those incident to a kept edge
        surviving = torch.unique(sub_index)
        sub_x = data.x[surviving]
        # old-id -> new-id lookup so edge_index addresses the compacted node set
        lookup = torch.zeros(index_to_mask(surviving).size(0), dtype=torch.long)
        lookup[surviving] = torch.arange(len(surviving))
        sub_index = lookup[sub_index]
        return Data(x=sub_x, edge_index=sub_index, edge_weight=sub_weight,
                    edge_attr=sub_attr, y=data.y)
# pre-processing and data augmentation
import torch_geometric.transforms as T
# transforms1: base pipeline — normalize node features, append the target
# node's in-degree as an extra edge attribute (cat=True), apply GCN
# normalization, and move all tensors to `device`
transforms1 = T.Compose([
GraphNormalize(norm="nodes"),
T.TargetIndegree(cat=True),
T.GCNNorm(),
T.ToDevice(device=device)
])
# transforms2: augmentation pipeline — same as above, but first keeps a
# random 90% of the edges to produce a slightly smaller graph
transforms2 = T.Compose([
EdgeSampler(num_edges=0.90),
GraphNormalize(norm="nodes"),
T.TargetIndegree(cat=True),
T.GCNNorm(),
T.ToDevice(device=device)
])
# GraphNormalize mutates Data objects in place, so a fresh copy of the
# dataset is loaded for each transform chain
metabolic_dataset = torch.load("metabolic_graphs.pt") # I have to load it again to apply a new transform chain
dataset = [transforms1(data) for data in metabolic_dataset]
# as a data augmentation technique, I sample the edges of the original graphs to obtain new slightly smaller graphs
metabolic_dataset2 = torch.load("metabolic_graphs.pt")
augmentations1 = [transforms2(data) for data in metabolic_dataset2]
metabolic_dataset3 = torch.load("metabolic_graphs.pt")
augmentations2 = [transforms2(data) for data in metabolic_dataset3]
# model definition
from torch.nn import Linear, ReLU, ModuleList
import torch.nn.functional as F
from torch_geometric.nn import DeepGCNLayer, GraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.norm import GraphNorm
class DeeperGCN(torch.nn.Module):
    """Graph-level classifier: a GraphConv stem, `num_layers` plain
    DeepGCN layers (conv -> GraphNorm -> ReLU with dropout 0.5), global
    mean pooling, dropout, and a linear classification head.

    Args:
        hidden_channels: width of every hidden layer.
        num_layers: number of stacked DeepGCN layers after the stem.
        num_classes: size of the output layer (one logit per class).
    """

    def __init__(self, hidden_channels, num_layers, num_classes=4):
        super().__init__()
        self.conv1 = GraphConv(2, hidden_channels)  # each node has 2 features
        self.norm1 = GraphNorm(hidden_channels)
        self.layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GraphConv(hidden_channels, hidden_channels)
            norm = GraphNorm(hidden_channels)
            act = ReLU(inplace=True)
            # block="plain": no residual/dense connection, just conv-norm-act
            self.layers.append(DeepGCNLayer(conv, norm, act, block="plain", dropout=0.5))
        # BUG FIX: the head was hard-coded to 2 outputs while the weighted
        # CrossEntropyLoss below carries 4 class weights (a 4-class problem),
        # which mismatches at loss time; the size is now configurable and
        # defaults to the 4 classes the criterion expects.
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, edge_weight, batch=None):
        """Return per-graph class logits of shape (num_graphs, num_classes)."""
        x = self.conv1(x, edge_index, edge_weight)
        x = self.norm1(x)
        x = x.relu()
        for layer in self.layers:
            x = layer(x, edge_index, edge_weight)
        # collapse node embeddings into one embedding per graph
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        return self.lin(x)
# training and testing
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.model_selection import StratifiedKFold
# 5-fold stratified CV so each fold keeps the class proportions of the full set
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
# inverse-frequency class weights (majority class weight = 1); presumably the
# class counts are 114/140/47/98 out of ~400 samples — TODO confirm
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor([140/114, 1, 140/47, 140/98]).to(device)) # to balance the training
def train():
    """Run one training epoch over the global `train_loader`.

    Uses the module-level `model`, `optimizer` and `criterion`, and appends
    the epoch's mean per-sample loss to the global `train_loss` list.
    """
    model.train()
    running_loss = 0.0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_weight, batch=data.batch)
        loss = criterion(out, data.y)
        # BUG FIX: the original `running_loss =+ loss.item()` assigned
        # (+loss) each iteration instead of accumulating, so only the last
        # batch's loss was ever recorded. Weight by the batch size so the
        # value divided below is a true per-sample average.
        running_loss += loss.item() * data.num_graphs
        loss.backward()
        optimizer.step()
    train_loss.append(running_loss / len(train_loader.dataset))
@torch.no_grad()  # inference only: no graph building / gradient tracking
def test(loader, val_accuracy, val_loss):
    """Evaluate the global `model` on `loader`.

    Args:
        loader: DataLoader yielding batched Data objects.
        val_accuracy: list that receives this evaluation's accuracy.
        val_loss: list that receives this evaluation's mean per-sample loss.

    Returns:
        The fraction of correctly classified graphs in `loader`.
    """
    model.eval()
    correct = 0
    running_loss = 0.0
    for data in loader:
        out = model(data.x, data.edge_index, data.edge_weight, batch=data.batch)
        loss = criterion(out, data.y)
        # BUG FIX: `running_loss =+ loss.item()` assigned (+loss) instead of
        # accumulating; weight by batch size for a per-sample average.
        running_loss += loss.item() * data.num_graphs
        pred = out.argmax(dim=1)  # use the class with the highest logit
        correct += int((pred == data.y).sum())  # check against ground truth
    val_accuracy.append(correct / len(loader.dataset))
    val_loss.append(running_loss / len(loader.dataset))
    return correct / len(loader.dataset)  # ratio of correct predictions
# collect per-graph labels for stratified splitting
y = []
for data in dataset:
    y.append(data.y)
y = torch.Tensor(y)
k = 0
for train_index, test_index in cv.split(dataset, y.cpu()):
    train_data = [dataset[i] for i in train_index]
    test_data = [dataset[i] for i in test_index]
    # the augmentation lists are index-aligned with `dataset`, so selecting
    # by train_index keeps augmented copies of test graphs out of training
    augmented_data = (train_data
                      + [augmentations1[i] for i in train_index]
                      + [augmentations2[i] for i in train_index])
    # BUG FIX: the training loader had no shuffle, so every epoch saw the
    # exact same batch composition and order (all originals first, then the
    # two augmentation passes) — shuffle each epoch for proper SGD.
    train_loader = DataLoader(augmented_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32)
    # fresh metric buffers and a fresh model/optimizer for each fold
    train_loss = []
    val_loss = []
    train_accuracy = []
    val_accuracy = []
    model = DeeperGCN(hidden_channels=512, num_layers=8)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-4)
    k = k + 1
    print("Split {}:".format(k))
    for epoch in range(1, 301):
        train()
        # train accuracy is tracked, but its loss list is discarded
        train_acc = test(train_loader, train_accuracy, [])
        test_acc = test(test_loader, val_accuracy, val_loss)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\n')
        if np.isnan(train_loss).any():
            break  # training diverged on this fold; move to the next split
```

Many thanks