I implemented the Graph Convolutional Network (GCN) for node classification, but I got test accuracy is zero.
My code:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv #GATConv
class GCN(torch.nn.Module):
def __init__(self, hidden_channels):
super(GCN, self).__init__()
torch.manual_seed(42)
# Initialize the layers
self.conv1 = GCNConv(data.num_features, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
self.out = Linear(hidden_channels, 8)
def forward(self, x, edge_index):
# First Message Passing Layer (Transformation)
x = self.conv1(x, edge_index)
x = x.relu()
x = F.dropout(x, p=0.5, training=self.training)
# Second Message Passing Layer
x = self.conv2(x, edge_index)
x = x.relu()
x = F.dropout(x, p=0.5, training=self.training)
# Output layer
x = F.softmax(self.out(x), dim=1)
return x
model = GCN(hidden_channels=16)
Training and Evaluation
# Initialize model
model = GCN(hidden_channels=16)
# Initialize Optimizer
learning_rate = 0.01
decay = 5e-4
optimizer = torch.optim.Adam(model.parameters(),
lr=learning_rate,
weight_decay=decay)
# Define loss function (CrossEntropyLoss for Classification Problems with
# probability distributions)
criterion = torch.nn.CrossEntropyLoss()
def train():
model.train()
optimizer.zero_grad()
# Use all data as input, because all nodes have node features
out = model(data.x.float(), data.edge_index)
# Only use nodes with labels available for loss calculation --> mask
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss
def test():
model.eval()
out = model(data.x.float(), data.edge_index)
# Use the class with highest probability.
pred = out.argmax(dim=1)
# Check against ground-truth labels.
test_correct = pred[data.test_mask] == data.y[data.test_mask]
# Derive ratio of correct predictions.
test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
return test_acc
losses = []
for epoch in range(0, 100):
loss = train()
test_acc = test()
losses.append(loss)
if epoch % 10 == 0:
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')
The output:
Epoch: 000, Loss: 2.2735, Test Accuracy: 0.0000
Epoch: 010, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 020, Loss: 2.2740, Test Accuracy: 0.0000
Epoch: 030, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 040, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 050, Loss: 2.2740, Test Accuracy: 0.0000
Epoch: 060, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 070, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 080, Loss: 1.2740, Test Accuracy: 0.0000
Epoch: 090, Loss: 1.2740, Test Accuracy: 0.0000