I made a composite model, MainModel, which consists of a GinEncoder and some Linear layers. The GinEncoder is built with the torch-geometric package, as shown in the following code:
import torch
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_geometric.nn import GINConv, global_add_pool

class GinEncoder(torch.nn.Module):
    def __init__(self):
        super(GinEncoder, self).__init__()
        self.gin_convs = torch.nn.ModuleList()
        self.gin_convs.append(GINConv(Sequential(Linear(1, 4),
                                                 BatchNorm1d(4), ReLU(),
                                                 Linear(4, 4), ReLU())))
        self.gin_convs.append(GINConv(Sequential(Linear(4, 4),
                                                 BatchNorm1d(4), ReLU(),
                                                 Linear(4, 4), ReLU())))

    def forward(self, x, edge_index, batch_node_id):
        # Node embeddings
        nodes_emb_layers = []
        for i in range(2):
            x = self.gin_convs[i](x, edge_index)
            nodes_emb_layers.append(x)

        # Graph-level readout
        nodes_emb_pools = [global_add_pool(nodes_emb, batch_node_id) for nodes_emb in nodes_emb_layers]

        # Concatenate and form the graph embeddings
        graph_embeds = torch.cat(nodes_emb_pools, dim=1)
        return graph_embeds

    def get_embeddings(self, x, edge_index, batch_node_id):
        with torch.no_grad():
            graph_embeds = self.forward(x, edge_index, batch_node_id).reshape(-1)
        return graph_embeds
class MainModel(torch.nn.Module):
    def __init__(self, graph_encoder: torch.nn.Module):
        super(MainModel, self).__init__()
        self.graph_encoder = graph_encoder
        self.lin1 = Linear(8, 4)
        self.lin2 = Linear(4, 8)

    def forward(self, x, edge_index, batch_node_id):
        graph_embeds = self.graph_encoder(x, edge_index, batch_node_id)
        out_lin1 = self.lin1(graph_embeds)
        pred = self.lin2(out_lin1)[-1]
        return pred

gin_encoder = GinEncoder().to("cuda")
model = MainModel(gin_encoder).to("cuda")
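For reference, here is a minimal sketch of how I understand submodule registration to work (plain torch, with hypothetical Encoder/Composite stand-ins for my actual classes): assigning a module as an attribute should make its parameters show up in the composite's model.parameters(), so the optimizer should already cover both.

```python
import torch
from torch import nn

# Hypothetical stand-ins for GinEncoder / MainModel, to show how
# nn.Module registers a submodule's parameters automatically.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 4)   # 4 weights + 4 biases = 8 params

class Composite(nn.Module):
    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder       # attribute assignment registers the submodule
        self.head = nn.Linear(4, 2)  # 8 weights + 2 biases = 10 params

enc = Encoder()
model = Composite(enc)

# The composite's parameter list includes the encoder's parameters,
# so optimizer = Adam(model.parameters()) should cover both.
n_enc = sum(p.numel() for p in enc.parameters())
n_all = sum(p.numel() for p in model.parameters())
print(n_enc, n_all)  # 8 18
```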
I found that the weights of GinEncoder were not updated, while the weights of the Linear layers in MainModel were updated. I observed this with the following code:
from itertools import islice

gin_encoder = GinEncoder().to("cuda")
model = MainModel(gin_encoder).to("cuda")
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
epochs =
for epoch_i in range(epochs):
    model.train()
    train_loss = 0
    for batch_i, data in enumerate(train_loader):
        data.to("cuda")
        x, x_edge_index, x_batch_node_id = data.x, data.edge_index, data.batch
        y, y_edge_index, y_batch_node_id = data.y[-1].x, data.y[-1].edge_index, torch.zeros(data.y[-1].x.shape[0], dtype=torch.int64).to("cuda")
        optimizer.zero_grad()
        graph_embeds_pred = model(x, x_edge_index, x_batch_node_id)
        y_graph_embeds = model.graph_encoder.get_embeddings(y, y_edge_index, y_batch_node_id)
        loss = criterion(graph_embeds_pred, y_graph_embeds)
        train_loss += loss.item()  # .item() so the graph is not retained across batches
        loss.backward()
        optimizer.step()
        if batch_i == 0:
            print(f"NO. {epoch_i} EPOCH")
            print(f"MainModel weights in epoch_{epoch_i}_batch0: {next(islice(model.parameters(), 15, 16))}", end="\n\n")
            print(f"GinEncoder weights in epoch_{epoch_i}_batch0: {next(model.graph_encoder.parameters())}")
            print("*" * 80)
Output of the code:
NO. 0 EPOCH
MainModel weights in epoch_0_batch0:Parameter containing:
tensor([-0.1447, -0.3689, -0.2840, -0.3619, -0.2040, 0.2430, 0.4651, 0.3736],
device='cuda:0', requires_grad=True)
GinEncoder weights in epoch_0_batch0:Parameter containing:
tensor([[-0.8312],
[-0.5712],
[-0.6963],
[-0.1601]], device='cuda:0', requires_grad=True)
********************************************************************************
NO. 1 EPOCH
MainModel weights in epoch_1_batch0:Parameter containing:
tensor([-0.1842, -0.3333, -0.3170, -0.3247, -0.2424, 0.2627, 0.4272, 0.4119],
device='cuda:0', requires_grad=True)
GinEncoder weights in epoch_1_batch0:Parameter containing:
tensor([[-0.8312],
[-0.5712],
[-0.6963],
[-0.1601]], device='cuda:0', requires_grad=True)
********************************************************************************
NO. 2 EPOCH
MainModel weights in epoch_2_batch0:Parameter containing:
tensor([-0.2302, -0.3077, -0.3251, -0.2905, -0.2847, 0.2558, 0.3881, 0.4527],
device='cuda:0', requires_grad=True)
GinEncoder weights in epoch_2_batch0:Parameter containing:
tensor([[-0.8312],
[-0.5712],
[-0.6963],
[-0.1601]], device='cuda:0', requires_grad=True)
********************************************************************************
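To narrow this down, one check that can be run (a sketch with plain torch and hypothetical modules, not my actual GinEncoder) is whether gradients reach the encoder at all after loss.backward(): if the encoder's .grad is still None afterwards, the graph between the loss and the encoder is broken somewhere.

```python
import torch
from torch import nn

# Hypothetical encoder/head pair standing in for GinEncoder / the
# Linear layers of MainModel.
encoder = nn.Linear(3, 4)
head = nn.Linear(4, 2)

x = torch.randn(5, 3)
target = torch.randn(5, 2)

pred = head(encoder(x))          # loss depends on the encoder through pred
loss = nn.MSELoss()(pred, target)
loss.backward()

# If gradients flow, encoder.weight.grad is populated; if it were None,
# something (e.g. a stray torch.no_grad() or .detach()) cut the graph.
print(encoder.weight.grad is not None)  # True
```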
My question is: how do I make loss.backward() and optimizer.step() also update the weights of GinEncoder?
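For what it's worth, I believe the torch.no_grad() inside get_embeddings only affects the target side of the loss, since outputs computed under it are detached from the graph. A minimal check of that behavior (plain torch, a hypothetical Linear in place of the encoder):

```python
import torch
from torch import nn

# Outputs computed under torch.no_grad() are detached, so they are
# safe to use as regression targets without affecting gradient flow
# on the prediction side.
lin = nn.Linear(2, 2)
x = torch.randn(3, 2)

with torch.no_grad():
    y = lin(x)

print(y.requires_grad)  # False
```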
PS.
- I put the complete code here: a composite model composed of pytorch and torch-geometric · GitHub
- I put the training data on Google Drive: tmp - Google Drive