ValueError: Expected input batch_size (3622) to match target batch_size (128) in Pytorch_Geometric

I have spent a significant amount of time on this issue and searched for possible solutions, but without any success. I have no idea where the number 3622 comes from. If I understand correctly, the input batch size and target batch size refer to the two tensors passed to the loss function, but I don't know how torch arrives at this 3622 number. The following is my code.

Any suggestions are deeply appreciated!

import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn.functional as F 
from dataset import MoleculeDataset
from torch_geometric.loader import DataLoader
from sklearn.metrics import confusion_matrix, f1_score,\
                            accuracy_score, precision_score,\
                            recall_score, roc_auc_score
from torch.nn import Linear, BatchNorm1d, ModuleList, Sequential, ReLU, LeakyReLU
from torch_geometric.nn import GATConv


# hyperparameters        
seed = 2022                              # random state
epochs = 1                               # num of epochs
batch_size = 128                         # num of graph per batch 

# load the datasets
train_dataset = MoleculeDataset(root= f"{os.getcwd()}/", filename="TrainSet.csv")
test_dataset = MoleculeDataset(root= f"{os.getcwd()}/", filename="TestSet.csv")


# set up the device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')          # run model on GPU if possible
device = torch.device('cpu')                                                     # run model on CPU

# load the model
model = GNN(feature_size=train_dataset.num_features).to(device)

# count number of parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(model))
print(f"model architecture: {model}")

# loss function & optimizer
weights = torch.tensor([1, 5], dtype=torch.float32).to(device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.7, patience=5,
                                                       min_lr=1e-5)
# prepare the training
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = True)

# train a model
def train(epoch):
    # enumerate over the data
    all_preds, all_labels = [], []

    for _, batch in enumerate(tqdm(train_loader)):
        # use GPU if GPU is available
        batch.to(device)
        # reset gradients
        optimizer.zero_grad()
        # pass the node features and the connection info
        pred = model(batch.x.float(),
                     batch.edge_attr.float(),
                     batch.edge_index,
                     batch.batch)
        # calculate the loss and the gradients
        loss = torch.sqrt(loss_fn(pred, batch.y))
        loss.backward()
        # update using the gradients
        optimizer.step()

        all_preds.append(np.argmax(pred.detach().numpy(), axis=1))
        all_labels.append(batch.y.detach().numpy())
    
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    calculate_metrics(all_labels, all_preds, epoch, "train")
    return loss

def test(epoch):
    all_preds, all_labels = [], []
    for batch in test_loader:
        batch.to(device)
        pred = model(batch.x.float(),
                     batch.edge_attr.float(),
                     batch.edge_index,
                     batch.batch)
        loss = torch.sqrt(loss_fn(pred, batch.y))
        all_preds.append(np.argmax(pred.detach().numpy(), axis =1))
        all_labels.append(batch.y.detach().numpy())

    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    calculate_metrics(all_labels, all_preds, epoch, "test")
    return loss

def calculate_metrics(y_true, y_pred, epoch, type):
    print(f"confusion matrix: \n {confusion_matrix(y_true, y_pred)}")

# start training
for epoch in range(epochs):
    # training
    model.train()
    train_loss = train(epoch=epoch)
    train_loss = train_loss.detach().numpy()
    scheduler.step(train_loss)
    print(f"Epoch: {epoch+1} | Train Loss: {train_loss:0.4f}")
    print("="*90)

    # testing
    if epoch%1 == 0:
        test_loss = test(epoch=epoch)
        test_loss = test_loss.detach().numpy()
        print(f"Epoch: {epoch+1} | Test Loss: {test_loss:0.4f}")

The error message is

  0%|          | 0/7340 [00:00<?, ?it/s]
1st x: torch.Size([3622, 800])
2nd x: torch.Size([3622, 200])
3rd x: torch.Size([3622, 200])
4th x: torch.Size([3622, 2])

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [12], in <cell line: 2>()
      2 for epoch in range(epochs):
      3     # training
      4     model.train()
----> 5     train_loss = train(epoch=epoch)
      6     train_loss = train_loss.detach().numpy()
      7     scheduler.step(train_loss)

Input In [11], in train(epoch)
     12 pred = model(batch.x.float(),
     13              batch.edge_attr.float(),
     14              batch.edge_index,
     15              batch.batch)
     16 # calculate the loss and the gradients
---> 17 loss = torch.sqrt(loss_fn(pred, batch.y))
     18 loss.backward()
     19 # update using the gradients

File ~/.conda/envs/radips/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/radips/lib/python3.8/site-packages/torch/nn/modules/loss.py:1163, in CrossEntropyLoss.forward(self, input, target)
   1162 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1163     return F.cross_entropy(input, target, weight=self.weight,
   1164                            ignore_index=self.ignore_index, reduction=self.reduction,
   1165                            label_smoothing=self.label_smoothing)

File ~/.conda/envs/radips/lib/python3.8/site-packages/torch/nn/functional.py:2996, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2994 if size_average is not None or reduce is not None:
   2995     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2996 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

ValueError: Expected input batch_size (3622) to match target batch_size (128).

This error is usually raised when intermediate tensors are flattened in the wrong way, which changes the batch size.
E.g. this pattern is often used:

x = x.view(-1, feature_size)

which would change the batch size of x if feature_size does not match the number of elements of all remaining dimensions.
The proper approach is to use:

x = x.view(x.size(0), -1)

to flatten a tensor, which would keep the batch size the same and would potentially raise a shape mismatch in the next layer in case the number of features is unexpected.

Check your model for these code snippets and try to fix them.
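For illustration, here is a minimal sketch (with made-up tensor sizes) showing how the two flattening styles differ:

import torch

# toy batch of 128 samples, each with a 4 x 50 feature map (made-up sizes)
x = torch.randn(128, 4, 50)

# wrong: if feature_size (here 100) does not equal 4 * 50 = 200,
# view(-1, 100) silently changes the batch dimension
wrong = x.view(-1, 100)
print(wrong.shape)            # torch.Size([256, 100])  -> batch size changed

# correct: keep the batch dimension and flatten the rest
right = x.view(x.size(0), -1)
print(right.shape)            # torch.Size([128, 200])  -> batch size preserved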

Thank you, @ptrblck. I fixed it by adding the following line to my forward function.

x = global_mean_pool(x, batch)
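
For anyone hitting the same error: the 3622 here is the total number of nodes across the 128 graphs in the batch, so the model was producing one prediction per node while batch.y holds one label per graph. Below is a minimal sketch of where the pooling step could sit in the forward method; the layer types and hidden sizes are assumptions, since the original GNN class is not shown:

import torch
from torch.nn import Linear
from torch_geometric.nn import GATConv, global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, feature_size):
        super().__init__()
        self.conv1 = GATConv(feature_size, 200)    # hidden size assumed
        self.conv2 = GATConv(200, 200)
        self.out = Linear(200, 2)                  # two classes

    def forward(self, x, edge_attr, edge_index, batch):
        # edge_attr is accepted to match the training-loop call but unused in this sketch
        x = self.conv1(x, edge_index).relu()       # per-node embeddings: [num_nodes, 200]
        x = self.conv2(x, edge_index).relu()
        # aggregate node embeddings into one vector per graph:
        # [num_nodes, 200] -> [num_graphs, 200], e.g. [3622, 200] -> [128, 200]
        x = global_mean_pool(x, batch)
        return self.out(x)                         # [num_graphs, 2] now matches batch.y

With the pooling in place, pred has shape [batch_size, 2], so the cross-entropy loss lines up with the 128 labels per batch.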