Test accuracy for GCN is zero

I implemented the Graph Convolutional Network (GCN) for node classification, but I got test accuracy is zero.

My code:

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv #GATConv

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(42)

        # Initialize the layers
        self.conv1 = GCNConv(data.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.out = Linear(hidden_channels, 8)

    def forward(self, x, edge_index):
        # First Message Passing Layer (Transformation)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)

        # Second Message Passing Layer
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)

        # Output layer 
        x = F.softmax(self.out(x), dim=1)
        return x

model = GCN(hidden_channels=16)

Training and Evaluation

# Initialize model
model = GCN(hidden_channels=16)

# Initialize Optimizer
learning_rate = 0.01
decay = 5e-4
optimizer = torch.optim.Adam(model.parameters(), 
                             lr=learning_rate, 
                             weight_decay=decay)
# Define loss function (CrossEntropyLoss for Classification Problems with 
# probability distributions)
criterion = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad() 
      # Use all data as input, because all nodes have node features
      out = model(data.x.float(), data.edge_index)  
      # Only use nodes with labels available for loss calculation --> mask
      loss = criterion(out[data.train_mask], data.y[data.train_mask])  
      loss.backward() 
      optimizer.step()
      return loss

def test():
      model.eval()
      out = model(data.x.float(), data.edge_index)
      # Use the class with highest probability.
      pred = out.argmax(dim=1)  
      # Check against ground-truth labels.
      test_correct = pred[data.test_mask] == data.y[data.test_mask]  
      # Derive ratio of correct predictions.
      test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  
      return test_acc

losses = []
for epoch in range(0, 100):
    loss = train()
    test_acc = test()
    losses.append(loss)
    if epoch % 10 == 0:
      print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

The output:

Epoch: 000, Loss: 2.2735, Test Accuracy: 0.0000
Epoch: 010, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 020, Loss: 2.2740, Test Accuracy: 0.0000
Epoch: 030, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 040, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 050, Loss: 2.2740, Test Accuracy: 0.0000
Epoch: 060, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 070, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 080, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 090, Loss: 1.2740, Test Accuracy: 0.0000

Unrelated issue, but remove the F.softmax operation in forward as nn.CrossEntropyLoss expects raw logits.

Could you check what the model predicts and what the corresponding targets are here?

pred[data.test_mask] == data.y[data.test_mask]  

I would expect to see at least the “random” accuracy.

---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
Input In [34], in <module>
     43 for data in graphData:
     44   #print(data)
     45   loss = train(data)
---> 46   test_acc = test(data)
     47 #print(test_acc)
     48 losses.append(loss)

Input In [34], in test(data)
     36 print(data.test_mask.sum())
     37 # Derive ratio of correct predictions.
---> 38 test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  
     39 return test_acc

ZeroDivisionError: division by zero

You are currently dividing by the sum of test_mask. If you want to calculate the accuracy you should divide by the number of entries/pixels, not the sum of the class indices.

1 Like