I have challenged myself to represent each image of MNIST dataset as a individual graph and later train GNN to classify each graph. To clarify - each node in the graph is a pixel, and each node has 2-4 edges to nearby pixels.
Though, I am getting the following error:
Traceback (most recent call last):
File "MNISTGraph2.py", line 103, in <module>
train(model, trainLoader, optimizer, device)
File "MNISTGraph2.py", line 32, in train
loss = F.nll_loss(output, data.y)
File ".conda\lib\site-packages\torch\nn\functional.py", line 2701, in nll_loss
return torch._C._nn.nll_loss_nd(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (784) to match target batch_size (1).
784 is the number of nodes (28*28 pixels) and 1 is the target value of the graph (number on MNIST image).
I believe the problem is that my target label is a graph-level, rather than node-level. Can you suggest how I should adjust the GNN to understand that I am classifying the whole graph, rather than each node?
Here is a google colab link to the code that produces the error: Google Colab
My code
from torchvision import transforms
import torch
from torchvision import datasets
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
trainDataset = datasets.MNIST(root="./datasets/MNIST", download=True, train=True, transform=transforms.ToTensor())
testDataset = datasets.MNIST(root="./datasets/MNIST", download=True, train=False,transform=transforms.ToTensor())
class myGNN(torch.nn.Module):
def __init__(self) -> None:
super(myGNN, self).__init__()
self.conv1 = GCNConv(1, 16)
self.conv2 = GCNConv(16, 10)
def forward(self, data):
x, edge_index = data.x, data.edge_index.to(torch.int64)
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
def train(model, loader, optimizer, device):
model.train()
for data in loader:
data = data.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, data.y)
loss.backward()
optimizer.step()
def test(model, loader, device):
model.eval()
correct = 0
with torch.no_grad():
for data in loader:
data = data.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(data.y).sum().item()
acc = correct / len(loader.dataset)
return acc
def getGraphs(dataset, length=None):
if length is None:
length = len(dataset)
graphs = []
for idx, image in enumerate(dataset.data):
if idx % 100 == 0:
print(idx)
if idx >= length:
break
imageHeight = image.shape[0]
imageWidth = image.shape[1]
numNodes = image.shape[0] * image.shape[1]
x = torch.zeros(numNodes, 1)
edge_index = torch.zeros(2, numNodes * 4 - (imageHeight + imageWidth - 2) * 2)
counter = 0
for i in range(imageHeight):
for j in range(imageWidth):
x[i*imageWidth + j] = image[i, j]
if i > 0:
edge_index[0, counter] = i*imageWidth + j
edge_index[1, counter] = (i-1)*imageWidth + j
counter += 1
if j > 0:
edge_index[0, counter] = i*imageWidth + j
edge_index[1, counter] = i*imageWidth + j - 1
counter += 1
if i < imageHeight - 1:
edge_index[0, counter] = i*imageWidth + j
edge_index[1, counter] = (i+1)*imageWidth + j
counter += 1
if j < imageWidth - 1:
edge_index[0, counter] = i*imageWidth + j
edge_index[1, counter] = i*imageWidth + j + 1
counter += 1
graph = Data(x=x, edge_index=edge_index, y=dataset.targets[idx])
graphs.append(graph)
return graphs
testGraphs = getGraphs(testDataset, length=100)
trainGraphs = getGraphs(trainDataset, length=600)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = myGNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
trainLoader = DataLoader(trainGraphs, shuffle=True)
testLoader = DataLoader(testGraphs, shuffle=False)
epochs = 50
for epoch in range(epochs):
train(model, trainLoader, optimizer, device)
print("Finished training")
test_acc = test(model, testLoader, device)
print('Epoch: {:03d}, Test Acc: {:.4f}'.format(epoch, test_acc))