Expected input batch_size (784) to match target batch_size (1)

I have challenged myself to represent each image of the MNIST dataset as an individual graph and then train a GNN to classify each graph. To clarify: each node in the graph is a pixel, and each node has 2-4 edges to its neighbouring pixels.

However, I am getting the following error:

Traceback (most recent call last):
  File "MNISTGraph2.py", line 103, in <module>
    train(model, trainLoader, optimizer, device)
  File "MNISTGraph2.py", line 32, in train
    loss = F.nll_loss(output, data.y)
  File ".conda\lib\site-packages\torch\nn\functional.py", line 2701, in nll_loss
    return torch._C._nn.nll_loss_nd(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (784) to match target batch_size (1).

784 is the number of nodes (28*28 pixels) and 1 is the single target value of the graph (the digit on the MNIST image).

I believe the problem is that my target label is graph-level rather than node-level. Can you suggest how I should adjust the GNN so that it classifies the whole graph rather than each node?
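
For reference, with the default batch_size of 1, printing the shapes inside train() shows the mismatch:

output = model(data)
print(output.shape)  # torch.Size([784, 10]) -- one prediction per node
print(data.y.shape)  # torch.Size([1])       -- one label per graph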

Here is a Google Colab link to the code that produces the error: Google Colab

My code:
from torchvision import transforms
import torch
from torchvision import datasets
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

trainDataset = datasets.MNIST(root="./datasets/MNIST", download=True, train=True, transform=transforms.ToTensor())
testDataset = datasets.MNIST(root="./datasets/MNIST", download=True, train=False,transform=transforms.ToTensor())

class myGNN(torch.nn.Module):
    def __init__(self) -> None:
        super(myGNN, self).__init__()
        self.conv1 = GCNConv(1, 16)
        self.conv2 = GCNConv(16, 10)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
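        # x has shape [num_nodes, 10] here -- one score vector per node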
        return F.log_softmax(x, dim=1)
    

def train(model, loader, optimizer, device):
    model.train()
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        optimizer.step()

def test(model, loader, device):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
    acc = correct / len(loader.dataset)
    return acc

def getGraphs(dataset, length=None):
    if length is None:
        length = len(dataset)
    graphs = []
    for idx, image in enumerate(dataset.data):
        if idx % 100 == 0:
            print(idx)
        if idx >= length:
            break
        imageHeight = image.shape[0]
        imageWidth = image.shape[1]
        numNodes = image.shape[0] * image.shape[1]
        
        x = torch.zeros(numNodes, 1)
        # Each border row/column is missing one neighbour per node, so the
        # grid has 4 * numNodes - 2 * (height + width) directed edges.
        edge_index = torch.zeros(2, numNodes * 4 - (imageHeight + imageWidth) * 2, dtype=torch.long)

        counter = 0
        for i in range(imageHeight):
            for j in range(imageWidth):
                x[i*imageWidth + j] = image[i, j]
                if i > 0:
                    edge_index[0, counter] = i*imageWidth + j
                    edge_index[1, counter] = (i-1)*imageWidth + j
                    counter += 1
                if j > 0:
                    edge_index[0, counter] = i*imageWidth + j
                    edge_index[1, counter] = i*imageWidth + j - 1
                    counter += 1
                if i < imageHeight - 1:
                    edge_index[0, counter] = i*imageWidth + j
                    edge_index[1, counter] = (i+1)*imageWidth + j
                    counter += 1
                if j < imageWidth - 1:
                    edge_index[0, counter] = i*imageWidth + j
                    edge_index[1, counter] = i*imageWidth + j + 1
                    counter += 1
        graph = Data(x=x, edge_index=edge_index, y=dataset.targets[idx])
        graphs.append(graph)
    return graphs

testGraphs = getGraphs(testDataset, length=100)
trainGraphs = getGraphs(trainDataset, length=600)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = myGNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

trainLoader = DataLoader(trainGraphs, shuffle=True)
testLoader = DataLoader(testGraphs, shuffle=False)

epochs = 50
for epoch in range(epochs):
    train(model, trainLoader, optimizer, device)
    print("Finished training")
    test_acc = test(model, testLoader, device)
    print('Epoch: {:03d}, Test Acc: {:.4f}'.format(epoch, test_acc))

Based on the error message it seems your model outputs logits for each pixel in the batch dimension (784 rows, one per node), while the common output for multi-class classification would have the shape [batch_size, nb_classes]. Since your label is graph-level, the model needs a readout step that pools the node embeddings of each graph into a single graph embedding before the final classification.
I’m not that familiar with PyG, but its global pooling layers (e.g. global_mean_pool) are designed for exactly this.
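
A minimal sketch of what that could look like, assuming torch_geometric's global_mean_pool and the batch vector its DataLoader attaches to each mini-batch (the 16-dimensional second conv layer and the extra Linear head are my choices, not taken from your code):

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class myGNN(torch.nn.Module):
    def __init__(self) -> None:
        super(myGNN, self).__init__()
        self.conv1 = GCNConv(1, 16)
        self.conv2 = GCNConv(16, 16)
        self.lin = torch.nn.Linear(16, 10)  # graph-level classification head

    def forward(self, data):
        # data.batch maps every node to the index of the graph it belongs to;
        # torch_geometric's DataLoader attaches it automatically.
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Readout: average the node embeddings per graph, collapsing
        # [num_nodes_in_batch, 16] down to [batch_size, 16].
        x = global_mean_pool(x, batch)
        return F.log_softmax(self.lin(x), dim=1)

With this readout the output has shape [batch_size, 10], which lines up with data.y in F.nll_loss.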