PyTorch Geometric GIN-Conv layers parameters not updating

I built a composite model MainModel, which consists of a GinEncoder plus some Linear layers. The GinEncoder is built with the torch-geometric package, as shown in the following code:

import torch
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
from torch_geometric.nn import GINConv, global_add_pool


class GinEncoder(torch.nn.Module):
    def __init__(self):
        super(GinEncoder, self).__init__()
        # Two GINConv layers, each wrapping a small MLP as the update function
        self.gin_convs = torch.nn.ModuleList()
        self.gin_convs.append(GINConv(Sequential(Linear(1, 4),
                                                 BatchNorm1d(4), ReLU(),
                                                 Linear(4, 4), ReLU())))
        self.gin_convs.append(GINConv(Sequential(Linear(4, 4),
                                                 BatchNorm1d(4), ReLU(),
                                                 Linear(4, 4), ReLU())))


    def forward(self, x, edge_index, batch_node_id):
        # Node embeddings
        nodes_emb_layers = []
        for i in range(2):
            x = self.gin_convs[i](x, edge_index)
            nodes_emb_layers.append(x)

        # Graph-level readout
        nodes_emb_pools = [global_add_pool(nodes_emb, batch_node_id) for nodes_emb in nodes_emb_layers]

        # Concatenate and form the graph embeddings
        graph_embeds = torch.cat(nodes_emb_pools, dim=1)
        return graph_embeds


    def get_embeddings(self, x, edge_index, batch_node_id):
        # Compute graph embeddings without tracking gradients (used for the target embeddings)
        with torch.no_grad():
            graph_embeds = self.forward(x, edge_index, batch_node_id).reshape(-1)

        return graph_embeds


class MainModel(torch.nn.Module):
    def __init__(self, graph_encoder:torch.nn.Module):
        super(MainModel, self).__init__()
        self.graph_encoder = graph_encoder
        self.lin1 = Linear(8, 4)
        self.lin2 = Linear(4, 8)


    def forward(self, x, edge_index, batch_node_id):
        graph_embeds = self.graph_encoder(x, edge_index, batch_node_id)
        out_lin1 = self.lin1(graph_embeds)
        pred = self.lin2(out_lin1)[-1]  # keep only the last graph's prediction

        return pred

gin_encoder = GinEncoder().to("cuda")
model = MainModel(gin_encoder).to("cuda")
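
For reference, here is a quick shape check of the encoder on a toy graph (a sketch; toy_x, toy_edge_index and toy_batch are just illustrative tensors): the two pooled 4-dimensional layers are concatenated into an 8-dimensional graph embedding, which is why lin1 takes 8 input features.

# Toy shape check (sketch): 3 nodes with 1 feature each, all in one graph.
toy_x = torch.randn(3, 1)
toy_edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
toy_batch = torch.zeros(3, dtype=torch.int64)

enc = GinEncoder()
print(enc(toy_x, toy_edge_index, toy_batch).shape)  # torch.Size([1, 8])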

I found that the weights of the GinEncoder were not updated, while the weights of the Linear layers in MainModel were. I observed this with the following code:

from itertools import islice  # islice is used for the parameter printing below

gin_encoder = GinEncoder().to("cuda")
model = MainModel(gin_encoder).to("cuda")
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
epochs = 100  # placeholder; the exact number of epochs was omitted in the original post

for epoch_i in range(epochs):
    model.train()
    train_loss = 0

    for batch_i, data in enumerate(train_loader):
        data.to("cuda")
        x, x_edge_index, x_batch_node_id = data.x, data.edge_index, data.batch
        y, y_edge_index, y_batch_node_id = data.y[-1].x, data.y[-1].edge_index, torch.zeros(data.y[-1].x.shape[0], dtype=torch.int64).to("cuda")
        optimizer.zero_grad()
        graph_embeds_pred = model(x, x_edge_index, x_batch_node_id)
        y_graph_embeds = model.graph_encoder.get_embeddings(y, y_edge_index, y_batch_node_id)
        loss =  criterion(graph_embeds_pred, y_graph_embeds)
        train_loss += loss.item()  # accumulate a plain float so the graph is not kept alive
        loss.backward()
        optimizer.step()
        if batch_i == 0:
            print(f"NO. {epoch_i} EPOCH")
            print(f"MainModel weights in epoch_{epoch_i}_batch0:{next(islice(model.parameters(), 15, 16))}", end="\n\n")
            print(f"GinEncoder weights in epoch_{epoch_i}_batch0:{next(model.graph_encoder.parameters())}")
            print("*"*80)

Output of the code:

NO. 0 EPOCH
MainModel weights in epoch_0_batch0:Parameter containing:
tensor([-0.1447, -0.3689, -0.2840, -0.3619, -0.2040,  0.2430,  0.4651,  0.3736],
       device='cuda:0', requires_grad=True)

GinEncoder weights in epoch_0_batch0:Parameter containing:
tensor([[-0.8312],
        [-0.5712],
        [-0.6963],
        [-0.1601]], device='cuda:0', requires_grad=True)
********************************************************************************
NO. 1 EPOCH
MainModel weights in epoch_1_batch0:Parameter containing:
tensor([-0.1842, -0.3333, -0.3170, -0.3247, -0.2424,  0.2627,  0.4272,  0.4119],
       device='cuda:0', requires_grad=True)

GinEncoder weights in epoch_1_batch0:Parameter containing:
tensor([[-0.8312],
        [-0.5712],
        [-0.6963],
        [-0.1601]], device='cuda:0', requires_grad=True)
********************************************************************************
NO. 2 EPOCH
MainModel weights in epoch_2_batch0:Parameter containing:
tensor([-0.2302, -0.3077, -0.3251, -0.2905, -0.2847,  0.2558,  0.3881,  0.4527],
       device='cuda:0', requires_grad=True)

GinEncoder weights in epoch_2_batch0:Parameter containing:
tensor([[-0.8312],
        [-0.5712],
        [-0.6963],
        [-0.1601]], device='cuda:0', requires_grad=True)
********************************************************************************

My question is: how do I make loss.backward() and optimizer.step() also update the GinEncoder?
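
(For reference: a quick sanity check, assuming the model and optimizer defined above, that the GinEncoder parameters are registered in MainModel and handed to the optimizer; since graph_encoder is assigned as an attribute of MainModel, its parameters should already be returned by model.parameters(). The variable names below are only illustrative.)

# Sanity check (sketch): every GinEncoder parameter should also appear in
# model.parameters() and in the optimizer's param groups.
encoder_ids = {id(p) for p in model.graph_encoder.parameters()}
model_ids = {id(p) for p in model.parameters()}
optim_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}

print(encoder_ids.issubset(model_ids))  # expected: True
print(encoder_ids.issubset(optim_ids))  # expected: True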

PS. Thank you for your patience in reading my question.

Could you check the .grad attribute of all parameters of the GinEncoder before and after the first .backward call to see if these gradients are calculated but might be small?

I observed that the gradients of the GinEncoder parameters are always 0.

I used the following code to observe the .grad attribute of the first-layer parameters of the GinEncoder:

        x, x_edge_index, x_batch_node_id = data.x, data.edge_index, data.batch
        y, y_edge_index, y_batch_node_id = data.y[-1].x, data.y[-1].edge_index, torch.zeros(data.y[-1].x.shape[0], dtype=torch.int64).to("cuda")
        optimizer.zero_grad()
        graph_embeds_pred = model(x, x_edge_index, x_batch_node_id)
        y_graph_embeds = model.graph_encoder.get_embeddings(y, y_edge_index, y_batch_node_id)
        loss =  criterion(graph_embeds_pred, y_graph_embeds)
        train_loss += loss.item()  # accumulate a plain float so the graph is not kept alive
        print(f"Before loss.backward(), MainModel weights.grad in epoch_{epoch_i}_batch{batch_i}:{next(islice(model.parameters(), 15, 16)).grad}", end="\n\n")
        print(f"Before loss.backward(), MainModel.graph_encoder weights.grad in epoch_{epoch_i}_batch{batch_i}:{next(model.graph_encoder.parameters()).grad}")
        loss.backward()
        print(f"After loss.backward(), MainModel weights.grad in epoch_{epoch_i}_batch{batch_i}:{next(islice(model.parameters(), 15, 16)).grad}", end="\n\n")
        print(f"After loss.backward(), MainModel.graph_encoder weights.grad in epoch_{epoch_i}_batch{batch_i}:{next(model.graph_encoder.parameters()).grad}")
        print("*"*80)

The result at epoch 0, batch 0:

Before loss.backward(), MainModel weights.grad in epoch_0_batch0:None

Before loss.backward(), MainModel.graph_encoder weights.grad in epoch_0_batch0:None
After loss.backward(), MainModel weights.grad in epoch_0_batch0:tensor([-0.0839,  0.0596, -0.1096,  0.0718,  0.1150,  0.0749,  0.0800,  0.0076],
       device='cuda:0')

After loss.backward(), MainModel.graph_encoder weights.grad in epoch_0_batch0:tensor([[0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')

The result at epoch 0, batch 1:

Before loss.backward(), MainModel weights.grad in epoch_0_batch1:tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')

Before loss.backward(), MainModel.graph_encoder weights.grad in epoch_0_batch1:tensor([[0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')
After loss.backward(), MainModel weights.grad in epoch_0_batch1:tensor([-0.0640,  0.0315, -0.0785,  0.0666,  0.1209,  0.0641,  0.0495, -0.0090],
       device='cuda:0')

After loss.backward(), MainModel.graph_encoder weights.grad in epoch_0_batch1:tensor([[0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')

OK, the results at least show that you are not detaching the operations from the computation graph, since the .grad attributes are at least populated.
I don't know enough about the GINConv implementation to comment on why the gradients might be zero, but @rusty1s might know.
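
(Aside, not from the original reply: a minimal illustration of this distinction with a plain Linear layer. A gradient that is populated but numerically zero means the graph is intact, whereas a detached output leaves .grad unset.)

import torch

lin = torch.nn.Linear(2, 2)
x = torch.randn(3, 2)

# Intact graph: backward reaches lin, so lin.weight.grad gets populated
# (it can still be numerically zero, e.g. behind a saturated activation).
lin(x).sum().backward()
print(lin.weight.grad is None)  # False

# Detached output: the result no longer requires grad, so calling backward on it
# would raise an error and lin.weight.grad would never be populated through it.
detached_out = lin(x).detach().sum()
print(detached_out.requires_grad)  # False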

Thank you for your reply.

I tried merging the two classes, MainModel and GinEncoder, to avoid the non-updating parameters of gin_convs.

It works: the parameters of gin_convs now change. However, this approach still does not resolve my confusion.

Here is the updated code:

class MainModel2(torch.nn.Module):
    def __init__(self):
        super(MainModel2, self).__init__()
        # The GIN convolutions and the prediction head now live in a single module
        self.gin_convs = torch.nn.ModuleList()
        self.gin_convs.append(GINConv(Sequential(Linear(1, 4), ReLU(),
                                                 Linear(4, 4), ReLU(),
                                                 BatchNorm1d(4))))
        self.gin_convs.append(GINConv(Sequential(Linear(4, 4), ReLU(),
                                                 Linear(4, 4), ReLU(),
                                                 BatchNorm1d(4))))
        self.lin1 = Linear(8, 4)
        self.lin2 = Linear(4, 8)


    def forward(self, x, edge_index, batch_node_id):
        # Node embeddings
        nodes_emb_layers = []
        for i in range(2):
            x = self.gin_convs[i](x, edge_index)
            nodes_emb_layers.append(x)

        # Graph-level readout
        nodes_emb_pools = [global_add_pool(nodes_emb, batch_node_id) for nodes_emb in nodes_emb_layers]

        # Concatenate and form the graph embeddings
        graph_embeds = torch.cat(nodes_emb_pools, dim=1)
        out_lin1 = self.lin1(graph_embeds)
        pred = self.lin2(out_lin1)[-1]

        return pred

Update:
The weights are actually updated; I summed up the .grad of every layer of the GinEncoder.

I found that the .grad of the first layer of the GinEncoder weights is occasionally zero; it has nothing to do with how the model is built from pytorch and torch_geometric.

So, as @ptrblck said, after checking the .grad attributes of all parameters, the weights are actually updated. I just missed it.
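
For completeness, this is roughly how the per-layer gradient check can be done after loss.backward() (a sketch; it assumes the original composite model with model.graph_encoder):

# Sketch: after loss.backward(), report the total absolute gradient per
# GinEncoder parameter; a zero value for the first layer in a single batch does
# not mean the parameter never receives gradients across batches.
for name, param in model.graph_encoder.named_parameters():
    grad_sum = param.grad.abs().sum().item() if param.grad is not None else None
    print(f"{name}: {grad_sum}")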