I am trying to implement a simple Heterogeneous Graph Learning example but I am getting the “NotImplementedError”. I tried following the instructions here but I am not sure what am I missing. Could someone please point out the problem in my code? Thanks in advance.
Below is my code:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import to_hetero
from torch_geometric.nn import SAGEConv
class GNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
data = HeteroData()
data['node_type_1'].x = torch.rand((2,1))
data['node_type_2'].x = torch.rand((8,4))
type1_type2_edges = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
type2_type2_edges = torch.tensor([[0,1],[0,2],[0,3],[1,0],[1,2],[1,3],
[2,0],[2,1],[2,3],[3,0],[3,1],[3,2],
[4,5],[4,6],[4,7],[5,4],[5,6],[5,7],
[6,4],[6,5],[6,7],[7,4],[7,5],[7,6],
[0,4],[1,5],[2,6],[3,7]])
data['node_type_1', 'actuates', 'node_type_2'].edge_index = type1_type2_edges.t().contiguous()
data['node_type_2', 'interconnect', 'node_type_2'].edge_index = type2_type2_edges.t().contiguous()
model = GNN(hidden_channels=64, out_channels=4)
model = to_hetero(model, data.metadata(), aggr='sum')