Dear experts,
I am trying to use a heterogeneous model on my heterogeneous data.
I used the same model as in the official documentation:
import torch
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero, global_add_pool

data = dataset[0]

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
but when I tried to add a pooling layer:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        x = global_add_pool(x, batch)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
print(model(data.x_dict, data.edge_index_dict, data.batch_dict))
it gives me this error message:
File /mnt/d/myname/GNN/hetero_classification.py:66, in train_graph_classifier()
64 print()
65 model = GNN()
---> 66 model = to_hetero(model, data.metadata(), aggr='max')
67 out = model(data.x_dict, data.edge_index_dict, data.batch_dict)
69 print(out)
File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/to_hetero_transformer.py:117, in to_hetero(module, metadata, aggr, input_map, debug)
26 def to_hetero(module: Module, metadata: Metadata, aggr: str = "sum",
27 input_map: Optional[Dict[str, str]] = None,
28 debug: bool = False) -> GraphModule:
29 r"""Converts a homogeneous GNN model into its heterogeneous equivalent in
30 which node representations are learned for each node type in
31 :obj:`metadata[0]`, and messages are exchanged between each edge type in
(...)
115 transformation in debug mode. (default: :obj:`False`)
116 """
--> 117 transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
118 return transformer.transform()
File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/to_hetero_transformer.py:141, in ToHeteroTransformer.__init__(self, module, metadata, aggr, input_map, debug)
133 def __init__(
134 self,
135 module: Module,
(...)
139 debug: bool = False,
140 ):
--> 141 super().__init__(module, input_map, debug)
143 unused_node_types = get_unused_node_types(*metadata)
144 if len(unused_node_types) > 0:
File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/fx.py:73, in Transformer.__init__(self, module, input_map, debug)
66 def __init__(
67 self,
68 module: Module,
69 input_map: Optional[Dict[str, str]] = None,
70 debug: bool = False,
71 ):
72 self.module = module
---> 73 self.gm = symbolic_trace(module)
74 self.input_map = input_map
75 self.debug = debug
File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/fx.py:358, in symbolic_trace(module, concrete_args)
354 self.submodule_paths = None
356 return self.graph
--> 358 return GraphModule(module, Tracer().trace(module, concrete_args))
File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/fx.py:351, in symbolic_trace.<locals>.Tracer.trace(self, root, concrete_args)
347 for module in self._autowrap_search:
348 st._autowrap_check(patcher, module.__dict__,
349 self._autowrap_function_ids)
350 self.create_node(
--> 351 'output', 'output', (self.create_arg(fn(*args)), ), {},
352 type_expr=fn.__annotations__.get('return', None))
354 self.submodule_paths = None
356 return self.graph
File /mnt/d/myname/GNN/models.py:36, in GNN.forward(self, x, edge_index, batch)
34 x = self.conv1(x, edge_index).relu()
35 x = self.conv2(x, edge_index)
---> 36 out = global_add_pool(x, batch)
37
38 return x
File ~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch_geometric/nn/pool/glob.py:30, in global_add_pool(x, batch, size)
28 if batch is None:
29 return x.sum(dim=-2, keepdim=x.dim() == 2)
---> 30 size = int(batch.max().item() + 1) if size is None else size
31 return scatter(x, batch, dim=-2, dim_size=size, reduce='add')
TypeError: int() argument must be a string, a bytes-like object or a number, not 'Proxy'
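My guess is that to_hetero symbolically traces the forward method with torch.fx, so at trace time batch is a Proxy and global_add_pool cannot evaluate batch.max() on it. For reference, here is a sketch of the workaround I am considering: convert only the message-passing part with to_hetero and apply the pooling afterwards, outside the traced module. The names GNNEncoder and HeteroGraphClassifier, and the choice to simply sum the per-node-type read-outs, are my own and not from the documentation.

import torch
from torch_geometric.nn import SAGEConv, to_hetero, global_add_pool

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

class HeteroGraphClassifier(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels):
        super().__init__()
        # Only the message-passing part goes through to_hetero, so torch.fx
        # never has to trace global_add_pool.
        self.encoder = to_hetero(GNNEncoder(hidden_channels, out_channels),
                                 metadata, aggr='sum')

    def forward(self, x_dict, edge_index_dict, batch_dict):
        x_dict = self.encoder(x_dict, edge_index_dict)
        # Number of graphs in the batch, so every node type pools to the same
        # shape even if some graphs have no nodes of a given type.
        num_graphs = int(max(b.max() for b in batch_dict.values())) + 1
        outs = [global_add_pool(x, batch_dict[node_type], size=num_graphs)
                for node_type, x in x_dict.items()]
        # Sum the per-node-type read-outs into one graph-level representation.
        return torch.stack(outs, dim=0).sum(dim=0)

model = HeteroGraphClassifier(data.metadata(), hidden_channels=64,
                              out_channels=dataset.num_classes)
out = model(data.x_dict, data.edge_index_dict, data.batch_dict)

Is this the right approach, or is there a way to make to_hetero handle the pooling layer directly?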
Could you please help me solve this issue?
Thank you very much!