Model.eval() not causing GraphNorm to use running stats

Hey!
I’ve been having this weird problem recently that I am unable to fix.
So essentially the problem is that when I use model.eval(), I expect the GraphNorm layers in the model to use running statistics to normalise the inputs. That would mean the batch size doesn’t matter, because the GraphNorm layer wouldn’t use the statistics of the current batch but the running statistics accumulated during training.
However, I’ve been playing around with the outputs of my model on the validation set, and I see that the model’s output depends on the batch size. This happens because the GraphNorm layers don’t use running statistics but the statistics of the batch itself, so the normalisation of a graph depends on the size of the batch it sits in.
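
For reference, the behaviour I had in mind is what a plain torch.nn.BatchNorm1d does after .eval(); here is a minimal standalone sketch (not my actual model) of what I mean:

import torch

# What I expected: after .eval(), BatchNorm1d normalises with its running
# statistics, so the result for a sample does not depend on what else is
# in the batch.
bn = torch.nn.BatchNorm1d(5)
bn.eval()

x = torch.randn(3, 5)
out_single = bn(x[:1])    # normalise the first sample on its own
out_batched = bn(x)[:1]   # normalise it as part of a bigger batch
print(torch.allclose(out_single, out_batched))  # True in eval mode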

The model that I’m using is shown below:

import torch
import torch_geometric.nn
from torch_geometric.nn import GraphNorm, GraphConv


class GraphConvLayer(torch.nn.Module):
    def __init__(self, prev_shape, output_shape):
        super(GraphConvLayer, self).__init__()
        self.conv = GraphConv(prev_shape, output_shape)
        self.norm = GraphNorm(output_shape)
        self.relu = torch.nn.LeakyReLU()

    def forward(self, x, edges, batch=None):
        x = self.conv(x, edges)
        x = self.norm(x)
        x = self.relu(x)
        return x 


# Define model
class MyModel(torch.nn.Module):
    def __init__(self, settings):
        super(MyModel, self).__init__()
        # torch.manual_seed(12345)
        self.input_bn = GraphNorm(settings['input_features'])
        
        # Define the convolution layers
        self.conv_process = torch.nn.ModuleList()
        previous_output_shape = settings['input_features']
        for layer_param in settings['conv_params']:
            # Hidden channels
            H = layer_param
            self.conv_process.append(GraphConvLayer(previous_output_shape, H))
            previous_output_shape = H
        
        # Define the fully connected layers 
        self.fc_process = torch.nn.ModuleList()
        for layer_param in settings['fc_params']:
            drop_rate, H = layer_param
            seq = torch.nn.Sequential(
                torch.nn.Linear(previous_output_shape, H),
                torch.nn.Dropout(p=drop_rate),
                torch.nn.LeakyReLU()
            )
            self.fc_process.append(seq)
            previous_output_shape = H

        # Final output layer
        self.output_mlp_linear = torch.nn.Linear(previous_output_shape, settings['output_classes'])
        

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        print(f'Before GraphNorm: {x[:2]}')
        x = self.input_bn(x)
        print(f'After GraphNorm: {x[:2]}')

        for layer in self.conv_process:
            x = layer(x, edge_index)

        x = torch_geometric.nn.global_mean_pool(x, batch)

        for layer in self.fc_process:
            x = layer(x)

        x = self.output_mlp_linear(x)
 
        return x

The main training loop that I use is shown here (with some things, like saving the model, taken out):

import torch
from preprocessing.inMemoryDatasetClass import InMemDataset
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import pickle
from torch.optim import lr_scheduler
from models.GraphConv import MyModel


# -----------------------------Define Model and Hyperparams-------------------------------
dataset_folder = "data/"
batch_size = 500

# Define the model settings
model_name = 'GraphConv'
settings = {
    'input_features' : 5,
    'conv_params' : [32, 32, 64, 64, 64, 128, 128, 128],
    'fc_params' : [(0.4, 128)],
    'output_classes' : 2
}


# Define optimizer settings
lr = 0.01
min_lr = None
max_lr = None
step = 25
gamma = 0.5
decay = 0.00001
epochs = 200
print(f"Number of epochs = {epochs}")




# -----------------------------Create model and set optimizer-----------------------------
model = MyModel(settings)
print(model)
print(f"Number of parameters in model = {sum([param.nelement() for param in model.parameters()])}")
# Now transfer to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
#print(f"GPU = {torch.cuda.get_device_name(0)}")
model.to(device)


# Setup optimizer
if lr is None:
    optimizer = torch.optim.Adam(model.parameters(), lr=min_lr, weight_decay=decay)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)
criterion = torch.nn.CrossEntropyLoss()
scheduler = lr_scheduler.StepLR(optimizer, step_size=step, gamma=gamma)
#scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=min_lr, max_lr=max_lr, cycle_momentum=False)



# -----------------------------Create train/validation datasets-----------------------------

# First let's read in the data into the MyDataset class 
dataset = InMemDataset(f"{dataset_folder}")
print(f'Number of samples = {len(dataset)}')
print(f"Number of features = {dataset.num_features}")

# Split into training and validation sets but always use the same seed so that the split 
# is the same every time.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

# Now put them into data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=8)
validation_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=8)



# -----------------------------Define training and validation functions-----------------------------
def train():
    model.train()
    print("Training:")
    for data in train_loader:  # Iterate in batches over the training dataset.
        
        optimizer.zero_grad()  # Clear gradients.
        model.zero_grad()

        data = data.to(device)
        # The format of this will depend on the model being used!
        out = model(data)  # Perform a single forward pass.


        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
    scheduler.step()


def test(loader):
    model.eval()
    print("Testing:")
    correct = 0
    running_loss = 0.0
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        out = model(data)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.

        loss = criterion(out, data.y)  # Compute the loss.
        num_samps_in_batch = int(data.batch[-1] + 1)  # number of graphs in this batch (equivalently, data.num_graphs)
        running_loss += loss.item() * num_samps_in_batch
    
    acc = correct / len(loader.dataset)  # Derive ratio of correct predictions.
    loss = running_loss / len(loader.dataset)
    return acc, loss 



# Use tensorboard to track all of the models and the scores
val_accuracy = 0
val_loss_value = 0


for epoch in range(1, epochs+1):
    train()
    train_acc, train_loss = test(train_loader)
    val_acc, val_loss = test(validation_loader)

    # Save the model with the best accuracy
    if val_acc > val_accuracy:
        val_accuracy = val_acc
        val_loss_value = val_loss

Now, to show that the GraphNorm layer is not normalising correctly, I ran the code below with the model shown above:

import torch
from preprocessing.inMemoryDatasetClass import InMemDataset
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import pickle
from torch.optim import lr_scheduler
from models.GraphConv import MyModel
import sys
import numpy as np 
import sklearn.metrics as metrics
import awkward as ak
import matplotlib.pyplot as plt


# -----------------------------Define Model and Hyperparams-------------------------------
dataset_folder = "data/"
dataset_name = dataset_folder.rstrip('/').split('/')[-1]
testing = False


# Define the model settings
settings = {
    'input_features' : 5,
    'conv_params' : [32, 32, 64, 64, 64, 128, 128, 128],
    'fc_params' : [(0.4, 128)],
    'output_classes' : 2 
}


# Define optimizer settings
lr = 0.05
gamma = 0.5
step_size = 30
decay = 0.00001

model_folder = 'model/'


# -----------------------------Create model and set optimizer-----------------------------
model = MyModel(settings)
print(model)
print(f"Number of parameters in model = {sum([param.nelement() for param in model.parameters()])}")
# Now transfer to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(f"GPU = {torch.cuda.get_device_name(0)}")
model.to(device)

# Setup optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)
criterion = torch.nn.CrossEntropyLoss()
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)



# -----------------------------Create train/validation datasets-----------------------------

# First let's read in the data into the MyDataset class 
dataset = InMemDataset(f"{dataset_folder}", test=testing)[:10]



# I then want to load in the best model and do the analysis work 
# Let me load the model parameters in
def load_checkpoint(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['epoch']


checkpoint_path = f"{model_folder}/model.pt"
model, optimizer, start_epoch = load_checkpoint(checkpoint_path, model, optimizer)



# I want to see how the output changes for the first sample 
sample = dataset[0]
print(f"first element of x of the first sample = {sample.x[:1]}")


model.eval()
sample = sample.to(device)

print(f'With a single sample:')
out = model(sample)

out = F.softmax(out, dim=1)
print(f'original output = {out}')


# Now if I put it in a loader let's see how it changes 
print('Now using a batchsize of 2')
train_loader = DataLoader(dataset, batch_size=2, shuffle=False, pin_memory=True, num_workers=8)
for sample_num, data in enumerate(train_loader):  # Iterate in batches over the dataset.
    data = data.to(device)

    out = model(data)
    out = F.softmax(out, dim=1)
    print(f'output with batch size of 2 = {out}')
    print(f'Output for the first sample = {out[0]}')
    break

And the output for this was:

first element of x of the first sample = tensor([[1.7860e+00, 4.1688e+02, 1.6377e-01, 2.8861e+00, 0.0000e+00]])
With a single sample:
Before GraphNorm: tensor([[1.7860e+00, 4.1688e+02, 1.6377e-01, 2.8861e+00, 0.0000e+00]],
       device='cuda:0')
After GraphNorm: tensor([[ 0.9719,  1.8792,  1.1225,  0.0162, -0.1422]], device='cuda:0',
       grad_fn=<SliceBackward0>)
original output = tensor([[0.0078, 0.9922]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Now using a batchsize of 2
Before GraphNorm: tensor([[1.7860e+00, 4.1688e+02, 1.6377e-01, 2.8861e+00, 0.0000e+00]],
       device='cuda:0')
After GraphNorm: tensor([[ 0.7649,  1.7346,  0.5454,  0.0331, -0.2000]], device='cuda:0',
       grad_fn=<SliceBackward0>)
output with batch size of 2 = tensor([[8.5697e-01, 1.4303e-01],
        [4.0395e-05, 9.9996e-01]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Output for the first sample = tensor([0.8570, 0.1430], device='cuda:0', grad_fn=<SelectBackward0>)

As can be seen, the first sample has
x[:1] = [1.7860e+00, 4.1688e+02, 1.6377e-01, 2.8861e+00, 0.0000e+00]
and when it is put through the GraphNorm on its own this becomes x[:1] = [[ 0.9719, 1.8792, 1.1225, 0.0162, -0.1422]].
BUT, with a batch size of 2, this changes to
x[:1] = [[ 0.7649, 1.7346, 0.5454, 0.0331, -0.2000]]
and the output for the sample changes massively from
[[0.0078, 0.9922]] to [0.8570, 0.1430]!

Any thoughts on why the GraphNorm isn’t working correctly here? Any help would be greatly appreciated!!

Thanks!

I don’t see any running stats in the GraphNorm implementation, and the self.training attribute isn’t used in its forward pass to switch between training and eval mode either.
Where did you read that GraphNorm uses running stats during eval?
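
You can check this quickly (a minimal sketch, assuming a default GraphNorm(5)): no buffers are registered, and even in eval mode the layer normalises with the statistics of whatever nodes you pass in:

import torch
from torch_geometric.nn import GraphNorm

norm = GraphNorm(5)
norm.eval()

# No running_mean / running_var buffers are registered, only the learnable
# parameters (weight, bias, mean_scale), so this list should be empty:
print(list(norm.named_buffers()))

# Even in eval mode the statistics come from the nodes passed in, so the
# result for the same nodes changes with the rest of the input:
x = torch.randn(10, 5)
out_full = norm(x)       # statistics over all 10 nodes (treated as one graph)
out_half = norm(x[:5])   # statistics over the first 5 nodes only
print(torch.allclose(out_full[:5], out_half))  # generally False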

Ahhh, my mistake. I guess I just assumed that GraphNorm would follow a similar procedure to the standard PyTorch BatchNorm, which uses the running stats at eval time (your answer does make a lot of sense though, because I was very confused about how this was being implemented when I looked at the documentation!).
Apologies for the potential ignorance, but doesn’t this become a bit of a problem, since the output is then so heavily dependent not only on the batch size but also on the samples in the batch?
Thanks!

I took a quick look at the paper and it seems the authors were not concerned and apparently explicitly skipped the running stats.
Check 3.3:

However, the concentration of batch-level statistics is heavily domain-specific. […]
We study how the batch-level statistics μB, σB deviate from the dataset-level statistics μD, σD. […]
We observe that for image tasks, the maximal deviation of the batch-level statistics from the dataset-level statistics is negligible (Figure 4) after a few epochs. In contrast, for the graph tasks, the variation of batch-level statistics stays large during training. Intuitively, the graph structure can be quite diverse and a single batch cannot well represent the entire dataset. Hence, the preconditioning property also may not hold for BatchNorm. In fact, the heavy batch noise may bring instabilities to the training.
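
If the batch dependence is a problem for your use case, one option (an untested sketch, reusing the structure of your GraphConvLayer) would be to swap GraphNorm for torch_geometric.nn.BatchNorm, which wraps torch.nn.BatchNorm1d over the node features and does track running stats, so .eval() behaves the way you originally expected:

import torch
from torch_geometric.nn import BatchNorm, GraphConv

# Hypothetical variant of the GraphConvLayer above that uses BatchNorm instead
# of GraphNorm; BatchNorm keeps running_mean / running_var, so the eval-mode
# output no longer depends on the rest of the batch.
class GraphConvLayerBN(torch.nn.Module):
    def __init__(self, prev_shape, output_shape):
        super().__init__()
        self.conv = GraphConv(prev_shape, output_shape)
        self.norm = BatchNorm(output_shape)
        self.relu = torch.nn.LeakyReLU()

    def forward(self, x, edges, batch=None):
        x = self.conv(x, edges)
        x = self.norm(x)
        x = self.relu(x)
        return x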