GNN does not predict/generalize well on test set

Hi guys,

I am not an expert on neural networks, but I am trying to learn.

I am dealing with 5-dimensional data: each feature has a 2D part and a 3D part. I stack them together and normalize by subtracting the mean and dividing by the standard deviation. I am trying to solve a binary classification problem. I use a GNN to classify each graph, which is built from these 5D features.
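
To make the preprocessing concrete, here is a minimal sketch of the stacking and normalization step in plain Python. The function names and the toy values are hypothetical, and it assumes the mean and standard deviation are computed on the training set only and then reused for the validation and test sets.

```python
from statistics import mean, pstdev

def stack_parts(part_2d, part_3d):
    """Concatenate the 2D and 3D parts of one feature into a 5D row."""
    return list(part_2d) + list(part_3d)

def fit_standardizer(rows):
    """Compute per-column mean and standard deviation from the training rows."""
    columns = list(zip(*rows))
    means = [mean(col) for col in columns]
    # Fall back to 1.0 for constant columns so the division stays defined.
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds

def standardize(rows, means, stds):
    """Subtract the mean and divide by the standard deviation, column by column."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]

# Toy data (made up for illustration only).
train = [stack_parts((0.0, 1.0), (2.0, 3.0, 4.0)),
         stack_parts((2.0, 3.0), (4.0, 5.0, 6.0))]
test = [stack_parts((1.0, 2.0), (3.0, 4.0, 5.0))]

means, stds = fit_standardizer(train)          # statistics from training data only
train_norm = standardize(train, means, stds)
test_norm = standardize(test, means, stds)     # same statistics reused on test data
```

If the test set were normalized with its own statistics instead, the two sets would no longer be on the same scale, which is one thing worth ruling out.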

I tried a few representations of the graphs:

  • Tree-like: a general root node connects to all the features (one edge from the root to each feature, one edge from each feature back to the root)
  • Nearest-neighbor style: I connect each feature to the N features that are closest to it in the 2D space; I tried N = 1 and N = 5
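
The two representations above can be sketched as edge lists in the `[sources, targets]` layout that torch_geometric uses for `edge_index`. This is only an illustration under assumptions: the helper names are hypothetical, the root is placed at node index 0, and the nearest-neighbor edges are directed from each node to its neighbors.

```python
import math

def star_edges(num_features):
    """Tree-like graph: root is node 0, features are nodes 1..num_features.

    Adds one edge root -> feature and one edge feature -> root per feature.
    """
    src, dst = [], []
    for i in range(1, num_features + 1):
        src += [0, i]
        dst += [i, 0]
    return [src, dst]

def knn_edges(points_2d, k):
    """Nearest-neighbor graph: edge from each node to its k closest nodes in 2D."""
    src, dst = [], []
    for i, p in enumerate(points_2d):
        others = [j for j in range(len(points_2d)) if j != i]
        # Sort the other nodes by Euclidean distance in the 2D part.
        others.sort(key=lambda j: math.dist(p, points_2d[j]))
        for j in others[:k]:
            src.append(i)
            dst.append(j)
    return [src, dst]

# Toy 2D positions (made up for illustration only).
points = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0)]
star = star_edges(len(points))
knn = knn_edges(points, k=1)
```

Note that with N = 1 the nearest-neighbor graph can split into several disconnected pieces, whereas the tree-like graph is always connected through the root.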

I am using training, validation, and test sets. The validation set is sliced from the training set, whereas the test set comes from different data, but it should be pretty similar in terms of distribution.
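
One rough way to check the "similar distribution" assumption is to compare per-feature statistics of the two sets. The sketch below is a hypothetical helper, not part of my current pipeline: it measures how far each test-set feature mean lies from the training mean, in units of the training standard deviation, on unnormalized data.

```python
from statistics import mean, pstdev

def feature_shift(train_rows, test_rows):
    """Return |mean_test - mean_train| / std_train for every feature column."""
    shifts = []
    for train_col, test_col in zip(zip(*train_rows), zip(*test_rows)):
        std = pstdev(train_col) or 1.0  # avoid dividing by zero for constant columns
        shifts.append(abs(mean(test_col) - mean(train_col)) / std)
    return shifts

# Toy rows (made up for illustration only): the second feature is clearly shifted.
train_rows = [[0.0, 10.0], [2.0, 12.0], [4.0, 14.0]]
test_rows = [[2.0, 20.0], [2.0, 22.0]]
shifts = feature_shift(train_rows, test_rows)
```

Values close to 0 suggest the feature is centered similarly in both sets; values of several standard deviations would point to a shift between the training and test data rather than to the model itself.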

Losses and accuracies are “good” for training and validation; on the test set, however, I usually see the loss increasing and the accuracy staying flat (and low). So basically, my network overfits.

To start, I am using this GNN sample, which should be pretty basic, taken from the torch_geometric tutorial. The model is the same. I use a batch_size of 128 and Adam with lr=1e-4.

I have tried multiple things, but nothing seems to show substantial improvements:

  • some data augmentation of the training set
  • changing some hyperparameters (learning rate, optimizer, batch size)
  • changing the amount of noise in the training set

Attached are some plots of losses and accuracies over epochs, from a very small grid search:
cyan) original model as in the link, with 64 cells per layer
magenta) just the first two layers of the original model, with 64 cells each
yellow) two layers with 32 cells each

I am using just the first (tree-like) graph representation; if I connect nodes based on their “vicinity”, my network's predictions are horrible.

I don’t know what to do next.
Is it the data? Could it just be a matter of grid search?

I probably lack a proper methodology. How would you start a deep learning project? How would you proceed? Is there some sort of best-practice guide that one can follow to make it work, isolating problems step by step?

Any help, suggestion, or reference would be much appreciated.