Pooling layer in a heterogeneous graph (PyTorch Geometric)

Dear experts,

I am trying to use a heterogeneous model on my heterogeneous data.
I used the same model as in the official documentation:

import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero, global_add_pool

# Take the item at index 0 of `dataset` (which is defined outside this
# snippet); its `metadata()` is passed to `to_hetero` below.
data = dataset[0]

class GNN(torch.nn.Module):
    """Two-layer GraphSAGE network written in homogeneous form.

    Args:
        hidden_channels: Output size of the first ``SAGEConv`` layer.
        out_channels: Output size of the second ``SAGEConv`` layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): the input feature sizes are not fixed here; PyG infers
        # them lazily on the first forward pass (see the SAGEConv docs).
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply conv1, a ReLU, then conv2 and return the result."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)


# Build the homogeneous model and convert it to its heterogeneous
# equivalent in a single step; `aggr='sum'` is forwarded to `to_hetero`.
model = to_hetero(
    GNN(hidden_channels=64, out_channels=dataset.num_classes),
    data.metadata(),
    aggr='sum',
)

But when I tried to add a pooling layer:

class GNN(torch.nn.Module):
    """Two-layer GraphSAGE network followed by a global add-pooling step.

    NOTE: this is the variant that fails when passed to ``to_hetero``.
    According to the traceback below, ``global_add_pool`` evaluates
    ``int(batch.max().item() + 1)`` when no ``size`` is given; during
    ``to_hetero``'s symbolic trace ``batch`` is a ``Proxy``, so ``int()``
    raises ``TypeError``.

    Args:
        hidden_channels: Output size of the first ``SAGEConv`` layer.
        out_channels: Output size of the second ``SAGEConv`` layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): the input feature sizes are not fixed here; PyG infers
        # them lazily on the first forward pass (see the SAGEConv docs).
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index, batch):
        """Apply conv1, a ReLU and conv2, then add-pool the result by `batch`."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        node_out = self.conv2(hidden, edge_index)
        # Functional pooling call: this is the line the tracer cannot handle.
        return global_add_pool(node_out, batch)


# Build the pooled model and convert it with `to_hetero`. Per the traceback
# below, the TypeError is raised inside this `to_hetero` call (while the
# model is being symbolically traced), before the model is run on `data`.
model = to_hetero(
    GNN(hidden_channels=64, out_channels=dataset.num_classes),
    data.metadata(),
    aggr='sum',
)
# Run the converted model on the per-type feature, edge-index and batch dicts.
print(model(data.x_dict, data.edge_index_dict, data.batch_dict))

It gives me this error message:

File /mnt/d/myname/GNN/hetero_classification.py:66, in train_graph_classifier()
     64 print()
     65 model = GNN()
---> 66 model = to_hetero(model, data.metadata(), aggr='max')
     67 out = model(data.x_dict, data.edge_index_dict, data.batch_dict)
     69 print(out)
File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/to_hetero_transformer.py:117, in to_hetero(module, metadata, aggr, input_map, debug)
     26 def to_hetero(module: Module, metadata: Metadata, aggr: str = "sum",
     27               input_map: Optional[Dict[str, str]] = None,
     28               debug: bool = False) -> GraphModule:
     29     r"""Converts a homogeneous GNN model into its heterogeneous equivalent in
     30     which node representations are learned for each node type in
     31     :obj:`metadata[0]`, and messages are exchanged between each edge type in
   (...)
    115             transformation in debug mode. (default: :obj:`False`)
    116     """
--> 117     transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
    118     return transformer.transform()

File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/to_hetero_transformer.py:141, in ToHeteroTransformer.__init__(self, module, metadata, aggr, input_map, debug)
    133 def __init__(
    134     self,
    135     module: Module,
   (...)
    139     debug: bool = False,
    140 ):
--> 141     super().__init__(module, input_map, debug)
    143     unused_node_types = get_unused_node_types(*metadata)
    144     if len(unused_node_types) > 0:

File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/fx.py:73, in Transformer.__init__(self, module, input_map, debug)
     66 def __init__(
     67     self,
     68     module: Module,
     69     input_map: Optional[Dict[str, str]] = None,
     70     debug: bool = False,
     71 ):
     72     self.module = module
---> 73     self.gm = symbolic_trace(module)
     74     self.input_map = input_map
     75     self.debug = debug

File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/fx.py:358, in symbolic_trace(module, concrete_args)
    354         self.submodule_paths = None
    356         return self.graph
--> 358 return GraphModule(module, Tracer().trace(module, concrete_args))

File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/fx.py:351, in symbolic_trace.<locals>.Tracer.trace(self, root, concrete_args)
    347     for module in self._autowrap_search:
    348         st._autowrap_check(patcher, module.__dict__,
    349                            self._autowrap_function_ids)
    350     self.create_node(
--> 351         'output', 'output', (self.create_arg(fn(*args)), ), {},
    352         type_expr=fn.__annotations__.get('return', None))
    354 self.submodule_paths = None
    356 return self.graph

File /mnt/d/myname/GNN/models.py:36, in GNN.forward(self, x, edge_index, batch)
     34 x = self.conv1(x, edge_index).relu()
     35 x = self.conv2(x, edge_index)
---> 36 out = global_add_pool(x, batch)
     37 
     38 return x

File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/pool/glob.py:30, in global_add_pool(x, batch, size)
     28 if batch is None:
     29     return x.sum(dim=-2, keepdim=x.dim() == 2)
---> 30 size = int(batch.max().item() + 1) if size is None else size
     31 return scatter(x, batch, dim=-2, dim_size=size, reduce='add')

TypeError: int() argument must be a string, a bytes-like object or a number, not 'Proxy'

Could you please help me solve this issue?

Thank you very much

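The solution is to replace the functional global_add_pool call with an aggregation module such as MeanAggregation() (or SumAggregation() to keep the sum pooling), as used here.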

Thanks to Matthias Fey