Batch size not matching in heterogeneous graph classification

I am trying to use batches with my heterogeneous model. I am not able to use global_mean_pool in the forward function since I get the error:
TraceError('symbolically traced variables cannot be used as inputs to control flow')

I tried using global_mean_pool as suggested here, but I get the error:
'NodeStorage' object has no attribute 'max'

Instead, I tried using MeanAggregation(), but now the model does not group the nodes by batch. The error is now:
ValueError: Expected input batch_size (1) to match target batch_size (2).

I reviewed this post, but x.size returns Proxy(getattr_2), so I can’t see what size it needs to be.

The code works if the batch size is 1. How can I make it work with a larger batch size?

Here is a minimal reproducible example with some fake data:

import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.nn import Linear, ReLU
from torch.utils.data import Subset
from torch_geometric.data import HeteroData
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MeanAggregation
from torch_geometric.nn import SAGEConv, to_hetero

# Build ten identical toy heterogeneous graphs, each with two node types
# ('node1', 'node2'), two edge types and a single graph-level label.
graphs = []
for _ in range(10):
    data = HeteroData()

    # Three nodes per type, each with one float feature.
    data['node1'].x = torch.tensor([0, 1, 2], dtype=torch.float).reshape(-1, 1)
    data['node2'].x = torch.tensor([0, 1, 2], dtype=torch.float).reshape(-1, 1)

    # node1 -> node2 edges. The dtype is set at construction time: calling
    # `.type(...)` without assigning the result leaves the tensor unchanged.
    data['node1', 'edge1', 'node2'].edge_index = torch.tensor(
        [[0, 1, 2], [0, 1, 2]], dtype=torch.long)
    # NOTE(review): this edge type has three edges but only two attribute
    # values were listed; the third value is a placeholder so that there is
    # one attribute row per edge. Replace it with the intended value.
    data['node1', 'edge1', 'node2'].edge_attr = torch.tensor(
        [123.0, 101.5, 0.0]).reshape(-1, 1)

    # node2 -> node2 edges, one attribute row per edge.
    data['node2', 'edge2', 'node2'].edge_index = torch.tensor(
        [[0, 1], [1, 2]], dtype=torch.long)
    data['node2', 'edge2', 'node2'].edge_attr = torch.tensor(
        [1.23, 0.34]).reshape(-1, 1)

    # Graph-level class label (integer class index).
    data.y = torch.tensor([1], dtype=torch.long)

    # Keep the finished graph; without this the list stays empty.
    graphs.append(data)

# Apply the ToUndirected transform to every graph and collect the results.
# Rebinding the loop variable alone would discard the transformed graph and
# leave `dataset` empty, so the results are gathered into the list directly.
transform = T.ToUndirected()
dataset = [transform(graph) for graph in graphs]

data = dataset[0]  # Get the first graph object.


# Report size, density and structural flags of the first graph.
node_count = data.num_nodes
edge_count = data.num_edges
print(f'Number of nodes: {node_count}')
print(f'Number of edges: {edge_count}')
print(f'Average node degree: {edge_count / node_count:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


# Random 75/25 train/test partition over the positions of the graphs.
N = len(dataset)
indices = np.random.permutation(np.arange(N))

# Everything before the cut goes to training, the remainder to testing.
cut = int(0.75 * N)
train_indices = indices[:cut]
test_indices = indices[cut:]

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

# Mini-batches of two graphs; only the training loader reshuffles.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

class GNN(torch.nn.Module):
    """Two-layer GraphSAGE graph classifier written for a single node type.

    The module is converted with ``to_hetero`` right after construction, so
    ``forward`` is written as if there were one node type and one edge type.

    Args:
        hidden_channels: Output width of both SAGEConv layers.
        out_channels: Number of logits produced by the final linear layer.
    """

    def __init__(self, hidden_channels, out_channels):
        # torch.nn.Module must be initialised before submodules are assigned.
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.pool = MeanAggregation()
        # A Dropout module is used instead of F.dropout(training=self.training):
        # `self.training` is a plain bool, so FX tracing would record its value
        # at trace time and model.eval() could no longer switch dropout off.
        self.dropout = torch.nn.Dropout(p=0.5)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return one row of logits per group defined by ``batch``.

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor forwarded to both SAGEConv layers.
            batch: Index passed to ``MeanAggregation`` that groups the rows
                of ``x``; presumably the node-to-graph assignment vector.
        """
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        x = self.pool(x, batch)
        x = self.dropout(x)
        return self.lin(x)


# The classification head previously had a hard-coded width of 11; the same
# width is now passed explicitly through `out_channels`.
model = GNN(hidden_channels=64, out_channels=11)
model = to_hetero(model, data.metadata(), aggr='sum')

def train():
    """Run one optimisation pass over every mini-batch in ``train_loader``."""
    model.train()  # Make sure dropout is active while fitting.

    for batch in train_loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients left from the previous step.
        # The model was converted with to_hetero, so node-level inputs must be
        # dictionaries keyed by node type. Passing the whole `batch` object
        # gives the pooling layer no per-node graph assignment; `batch_dict`
        # supplies one assignment vector per node type.
        out = model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
        loss = criterion(out, batch.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

def test(loader):
    """Return the fraction of graphs in ``loader.dataset`` classified correctly.

    Args:
        loader: Data loader yielding batched heterogeneous graphs.

    Returns:
        Number of correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()  # Switch dropout off so predictions are deterministic.

    correct = 0
    with torch.no_grad():  # No gradients are needed for evaluation.
        for batch in loader:  # Iterate in batches over the training/test dataset.
            # Node-level inputs of a to_hetero model are dictionaries keyed by
            # node type, so the per-type assignment vectors are passed rather
            # than the batch object itself.
            out = model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
            pred = out.argmax(dim=1)  # Use the class with highest probability.
            correct += int((pred == batch.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.

# The SAGEConv layers are created with (-1, -1) input sizes, i.e. lazily
# sized; run one batch without gradients so their parameters are initialised
# before the optimizer is built (as recommended in the PyG documentation).
with torch.no_grad():
    init_batch = next(iter(train_loader))
    model(init_batch.x_dict, init_batch.edge_index_dict, init_batch.batch_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1, 21):
    train()  # Fit for one epoch; without this call the model is never updated.
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    # The first value is the training accuracy, not a loss.
    print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {test_acc:.4f}')