Thank you for the reply.
I tried merging the two classes, MainModel and GinEncoder, into one to work around the parameters of gin_convs not changing.
It works: the parameters of gin_convs now change. However, this approach still does not resolve my confusion.
Here is the updated code:
class MainModel2(torch.nn.Module):
    """GIN encoder and linear prediction head combined in one module.

    Two ``GINConv`` layers are applied in sequence; the node embeddings
    produced by each layer are sum-pooled with ``global_add_pool``, the
    pooled results are concatenated along ``dim=1``, and the result is
    passed through two linear layers.
    """

    def __init__(self):
        super().__init__()
        # The first layer's MLP takes 1 input feature, the second takes the
        # 4 features produced by the first. Modules are created in the same
        # order as before so parameter initialisation is unaffected.
        self.gin_convs = torch.nn.ModuleList(
            [GINConv(self._make_mlp(1)), GINConv(self._make_mlp(4))]
        )
        # 8 inputs: presumably 2 pooled embeddings of 4 features each,
        # concatenated in ``forward`` -- confirm against GINConv's output.
        self.lin1 = Linear(8, 4)
        self.lin2 = Linear(4, 8)

    @staticmethod
    def _make_mlp(in_features):
        """Return the MLP wrapped by each GIN layer, ending in 4 features."""
        return Sequential(
            Linear(in_features, 4), ReLU(),
            Linear(4, 4), ReLU(),
            BatchNorm1d(4),
        )

    def forward(self, x, edge_index, batch_node_id):
        """Run both GIN layers, pool, concatenate and apply the linear head.

        Args:
            x: Node features fed to the first GIN layer.
            edge_index: Edge index passed unchanged to every GIN layer.
            batch_node_id: Second argument to ``global_add_pool``
                (presumably the graph id of each node -- confirm with caller).

        Returns:
            The last entry along the first dimension of the ``lin2`` output.
        """
        pooled_per_layer = []
        for conv in self.gin_convs:
            # Each layer consumes the previous layer's node embeddings.
            x = conv(x, edge_index)
            # Graph-level readout of this layer's node embeddings.
            pooled_per_layer.append(global_add_pool(x, batch_node_id))

        # Concatenate the per-layer readouts to form the graph embeddings.
        graph_embeds = torch.cat(pooled_per_layer, dim=1)
        head_out = self.lin2(self.lin1(graph_embeds))
        # NOTE(review): only the final row is returned -- presumably the
        # last graph in the batch; confirm this is intended.
        return head_out[-1]