def __init__(self, graphs, max_nodes):
    """Store the graphs and the padding length used when items are fetched.

    Args:
        graphs: sequence of ``(graph_name, edge_index, edge_attr, max_flow)``
            tuples; ``edge_index`` is read along its second axis
            (``edge_index.size(1)``).
        max_nodes: length that each ``edge_index`` / ``edge_attr`` is padded
            to. Despite the name, this pads along the edge (column) axis.
    """
    self.graphs = graphs
    self.max_nodes = max_nodes
    # NOTE: this is the largest ``edge_index.size(1)`` (the column count,
    # i.e. the number of edges if edge_index is in (2, E) form), not a node
    # count. ``default=0`` lets an empty ``graphs`` sequence construct
    # instead of raising ValueError from ``max`` of an empty sequence.
    self.num_nodes = max(
        (edge_index.size(1) for _, edge_index, _, _ in graphs),
        default=0,
    )
def __len__(self):
    """Return how many graphs this dataset holds."""
    graph_list = self.graphs
    return len(graph_list)
def __getitem__(self, idx):
    """Return one graph with its edge tensors zero-padded to ``self.max_nodes``.

    Returns:
        ``(x, padded_edge_index, padded_edge_attr, max_flow)`` where ``x`` is
        a ``(self.num_nodes, 1)`` float32 matrix of ones (dummy node
        features), ``padded_edge_index`` is a ``(2, self.max_nodes)`` long
        tensor, and ``padded_edge_attr`` is a float32 tensor whose first
        dimension is ``self.max_nodes`` (any trailing feature dimensions of
        ``edge_attr`` are kept).
    """
    _, edge_index, edge_attr, max_flow = self.graphs[idx]
    num_edges = edge_index.size(1)

    # Zero-pad the edge list along its column axis.
    # NOTE(review): padded columns read as edges 0 -> 0, so a consumer that
    # does not mask them will see extra self-loops on node 0.
    padded_edge_index = torch.zeros((2, self.max_nodes), dtype=torch.long)
    padded_edge_index[:, :num_edges] = edge_index

    # Pad along the first (edge) axis only, so both 1-D edge weights and
    # (E, F) edge-feature matrices are supported.
    padded_edge_attr = torch.zeros(
        (self.max_nodes, *edge_attr.shape[1:]), dtype=torch.float
    )
    padded_edge_attr[: edge_attr.size(0)] = edge_attr

    # Dummy node-feature matrix: all ones (not zeros), one row per
    # ``self.num_nodes``.
    x = torch.ones(self.num_nodes, 1, dtype=torch.float)
    return x, padded_edge_index, padded_edge_attr, max_flow
class MPNN(nn.Module):
    """Encode-process-decode network over padded graph tensors.

    Args:
        num_node_features: currently unused by the model.
        num_edge_features: currently unused by the model.
        num_nodes: currently unused by the model.
        max_nodes: padded length of ``edge_attr`` per graph; ``forward``
            reshapes ``edge_attr`` to ``(batch, max_nodes)`` and the process
            layers take ``hidden_dim + max_nodes`` inputs.
        hidden_dim: width of the hidden representations.
        output_dim: number of output features produced by the decoder.
    """

    def __init__(self, num_node_features, num_edge_features, num_nodes,
                 max_nodes, hidden_dim, output_dim):
        super().__init__()
        # forward() and propagate() read these. Previously they were only
        # constructor locals, which caused a NameError / AttributeError at
        # runtime.
        self.max_nodes = max_nodes
        self.hidden_dim = hidden_dim

        # Encoder.
        # NOTE(review): the input width is hard-coded to 220. That matches a
        # flattened (batch, 110, 1) x concatenated with a (batch, 110)
        # edge_attr. It is NOT derived from the constructor arguments, so
        # confirm it before changing max_nodes or the dataset.
        self.encoder = nn.Sequential(
            nn.Linear(220, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Process: two message-passing MLPs. Each consumes a message made of
        # a hidden state concatenated with max_nodes edge values.
        self.process = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim + max_nodes, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            for _ in range(2)
        ])
        # Decoder: hidden_dim -> output_dim.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        # Extra hidden_dim -> hidden_dim projection, applied before decoding.
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_attr):
        """Encode, run the process layers, then project and decode.

        ``x`` must be 3-D; it is flattened to ``(batch, -1)``, and
        ``edge_attr`` is reshaped to ``(batch, self.max_nodes)``.
        """
        batch_size, _, _ = x.size()
        x = x.view(batch_size, -1)
        edge_attr = edge_attr.view(batch_size, self.max_nodes)
        x = self.encoder(torch.cat([x, edge_attr], dim=1))
        # NOTE(review): x is now (batch, hidden_dim), i.e. one row per graph
        # rather than per node, yet propagate() indexes it with node ids
        # taken from edge_index. If edge_index is batched as (batch, 2, E),
        # then edge_index[0] selects the first graph's edges, not the source
        # nodes. Also, propagate() calls .t() on a 3-D tensor derived from
        # this 2-D edge_attr, which raises. The message-passing shapes still
        # need redesigning.
        for process_layer in self.process:
            x = self.propagate(edge_index, x, edge_attr, process_layer)
        # fc maps hidden_dim -> hidden_dim, so it has to run before the
        # decoder narrows the width to output_dim. It previously ran after
        # the decoder and failed whenever output_dim != hidden_dim.
        x = self.fc(x)
        return self.decoder(x)

    def propagate(self, edge_index, x, edge_attr, process_layer):
        """Run one step of ``x[dst] += relu(process_layer([x[src] | edge_attr]))``.

        Args:
            edge_index: ``edge_index[0]`` holds source ids and
                ``edge_index[1]`` holds destination ids into ``x``.
            x: node states; ``x.shape[0]`` rows of ``hidden_dim`` features
                (required for ``x + aggr_message``).
            edge_attr: turned into a single ``(1, len)`` row. As written, the
                concatenation therefore only lines up for a 1-D edge_attr of
                length ``max_nodes`` and exactly one edge.
            process_layer: module mapping ``hidden_dim + max_nodes`` features
                to ``hidden_dim``.

        Returns:
            ``x`` plus the summed incoming messages per destination node.
        """
        row, col = edge_index[0], edge_index[1]
        # Message passing.
        # NOTE(review): Tensor.t() only accepts tensors with <= 2 dims, so a
        # 2-D edge_attr (3-D after unsqueeze) raises here.
        edge_attr_expanded = edge_attr.unsqueeze(-1)
        message = torch.cat([x[row], edge_attr_expanded.t()], dim=1)
        message = F.relu(process_layer(message))
        # Reduce: sum messages per destination node. Allocate on x's own
        # device and dtype instead of relying on an undefined global `device`.
        aggr_message = x.new_zeros(x.shape[0], self.hidden_dim)
        aggr_message.index_add_(0, col, message)
        return x + aggr_message
Some help, please.
Please mention the values you used for num_node_features, num_edge_features, num_nodes, max_nodes, hidden_dim, and output_dim; without them, it is not possible to answer this question.
Number of node features: 11
Number of edge features: 12
Number of nodes: 11
Maximum number of nodes: 110
Hidden dimension: 16
Output dimension: 1
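I am not sure this code is free of errors. For example, the line `edge_attr = edge_attr.view(batch_size, max_nodes)` in the `forward()` method does not have access to `max_nodes`. Should it be `self.max_nodes`? (Note that `__init__` never assigns `self.max_nodes` either, so it would also need `self.max_nodes = max_nodes`.)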
Please provide a minimal reproducible example (including the class instantiation code and data) showing how you use the model, so that copying, pasting, and running it reproduces the error directly.