Some layers' parameters not changing, while others are learning well

I’m trying to train a graph neural network for graph classification using PyTorch Geometric. I am feeding in graphs of varying size and classifying them as either 0 or 1. Note that I do not use an activation function on the output layer, as I am also running experiments with more than two classes. The classes are intended to be ordinal, which is why I use a single output neuron rather than a two-neuron output.

I’m finding that while some layers show changing gradients and parameters across training, others barely change from their initial values. In particular, the biases change significantly, while the weights in the message-passing phase barely change at all.

I’ve been measuring this change visually using the Weights & Biases tracking screen (I’ve included some images to show this). Changing from ReLU activations to ELU ones helped somewhat, suggesting that the model was suffering from the dying ReLU problem, though the problem wasn’t entirely fixed.

The gradients and parameters of the weights and biases on a particular layer in the message-passing phase (images not reproduced here):

My message-passing code:

class PaliwalMP(MessagePassing):
    """Define the message-passing scheme from the Subgraph Pooling paper.

    Each call runs two mean-aggregated propagations over the same node
    features:

    * over ``edge_index_parents``, with messages built by ``MLP_edge``
      (``direction='up'``);
    * over ``edge_index_children``, with messages built by ``MLP_edge_hat``
      (``direction='down'``).

    The node's own features and both aggregates are concatenated, passed
    through ``MLP_aggr`` and added back onto ``x`` as a residual.

    Args:
        in_channels: Width of the node features ``x``. Messages concatenate
            ``x_i``, ``x_j`` and the edge features, and every MLP here takes
            ``3 * in_channels`` inputs, so the edge features are expected to
            have width ``in_channels`` as well.
        out_channels: Accepted for interface compatibility but not used:
            the residual ``+ x`` in ``forward`` requires the output width to
            equal ``in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean', flow='source_to_target')  # "Mean" aggregation.
        # Step 2 of Paliwal MP: separate message MLPs for parent ('up') and
        # child ('down') edges.
        self.MLP_edge = BuildingBlock(3 * in_channels, in_channels)
        self.MLP_edge_hat = BuildingBlock(3 * in_channels, in_channels)
        # Step 3 of Paliwal MP: MLP over [x || parent aggr || child aggr].
        self.MLP_aggr = BuildingBlock(3 * in_channels, in_channels)

    def forward(self, x, edge_index_parents, edge_index_children,
                edge_attr_parents, edge_attr_children):
        """Run one message-passing step and return the updated node features.

        Args:
            x: Node features, shape ``[N, in_channels]``.
            edge_index_parents: Parent edges, shape ``[2, E/2]``.
            edge_index_children: Child edges, shape ``[2, E/2]``.
            edge_attr_parents: Edge features aligned with ``edge_index_parents``.
            edge_attr_children: Edge features aligned with ``edge_index_children``.

        Returns:
            ``MLP_aggr(cat[x, parent_aggr, child_aggr]) + x``.
        """
        # Extra keyword arguments to ``propagate`` are forwarded by name to
        # ``message``; ``x`` is expanded into ``x_i`` / ``x_j``.
        out_parents = self.propagate(edge_index_parents, x=x,
                                     edge_attr=edge_attr_parents,
                                     direction='up')
        out_children = self.propagate(edge_index_children, x=x,
                                      edge_attr=edge_attr_children,
                                      direction='down')
        out = torch.cat([x, out_parents, out_children], dim=1)
        return self.MLP_aggr(out) + x

    def message(self, x_i, x_j, edge_attr, direction):
        """Build one message per edge from target, source and edge features.

        Raises:
            ValueError: If ``direction`` is neither ``'up'`` nor ``'down'``.
                Previously such a value silently passed the raw
                ``3 * in_channels``-wide concatenation through.
        """
        s_ij = torch.cat([x_i, x_j, edge_attr], dim=1)
        if direction == 'up':
            return self.MLP_edge(s_ij)
        if direction == 'down':
            return self.MLP_edge_hat(s_ij)
        raise ValueError(f"direction must be 'up' or 'down', got {direction!r}")

    def update(self, aggr_out, x):
        """Return the aggregated messages unchanged; the MLP and residual are applied in ``forward``."""
        return aggr_out

And my higher-level GNN code, which stacks layers of the message-passing class:

class BuildingBlock(torch.nn.Module):
    """Standard two-layer MLP used throughout the Subgraph Pooling model.

    Computes ``elu(lin2(dropout(elu(lin1(x)))))``.

    Args:
        in_channels: Input feature width.
        out_channels: Output feature width.
        hidden_channels: Width of the hidden layer (default 128).
        dropout: Dropout probability applied after the first layer
            (default 0.3). Dropout is only active in training mode.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=128, dropout=0.3):
        super().__init__()
        self.lin1 = Linear(in_channels, hidden_channels)
        # Honour out_channels; a hard-coded width here would silently ignore
        # what the caller asked for.
        self.lin2 = Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x):
        """Apply the MLP to ``x`` (features on the last dimension)."""
        x = F.elu(self.lin1(x))
        # F.dropout defaults to training=True. Passing self.training makes
        # model.eval() actually disable dropout at validation/test time.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return F.elu(self.lin2(x))

class PaliwalNet(torch.nn.Module):
    """Implement the GNN from the Subgraph Pooling model.

    Accepts an arbitrary-size graph and produces a scalar output value. The
    output layer (``lin6``) has no activation.

    Args:
        t: Number of ``PaliwalMP`` message-passing steps.
        no_upsample: Stored on the module; not read anywhere in this class.
    """

    def __init__(self, t, no_upsample=False):
        super().__init__()
        self.no_upsample = no_upsample
        # Relies on module-level globals distinct_features, dataset_name and
        # embed_dim. NOTE(review): the "+1" presumably reserves an extra index
        # (padding/unknown); confirm.
        self.embedding = Embedding(num_embeddings=distinct_features[dataset_name] + 1,
                                   embedding_dim=embed_dim)
        # Lifts the 1-wide edge feature to 128 dims. PaliwalMP concatenates it
        # with two node embeddings into a 3*embed_dim input, so this only fits
        # when embed_dim == 128.
        self.MLP_E = BuildingBlock(1, 128)
        self.message_passing_steps = nn.ModuleList(
            PaliwalMP(embed_dim, embed_dim) for _ in range(t)
        )

        # kernel_size=1 convolutions act as per-node linear maps 128 -> 512 -> 1024.
        self.conv1 = nn.Conv1d(128, 512, 1)
        self.conv2 = nn.Conv1d(512, 1024, 1)

        # TODO: Try removing some layers or adding batch norm
        self.lin1 = Linear(1024, 512)
        self.lin2 = Linear(512, 512)
        self.lin3 = Linear(512, 256)
        self.lin4 = Linear(256, 256)
        self.lin5 = Linear(256, 128)
        self.lin6 = Linear(128, 1)

    def forward(self, data):
        """Return one raw (un-activated) score per pooled graph.

        Raises:
            ValueError: If the edge count is odd, because it then cannot be
                split into equal up/down halves.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # The first half of the edges is used as "up" (parent) edges and the
        # second half as "down" (child) edges.
        # NOTE(review): PyG collates a mini-batch by concatenating each graph's
        # edges one graph after another. If every graph stores [up | down],
        # the batched tensor is [up1 | down1 | up2 | down2 | ...], and halving
        # it only separates up from down when batch size is 1. Verify the
        # data layout.
        num_edges = edge_index.shape[1]
        if num_edges % 2:
            raise ValueError(f"expected an even number of edges, got {num_edges}")
        edge_index_u, edge_index_d = torch.split(edge_index, num_edges // 2, dim=1)
        edge_attr_u, edge_attr_d = torch.split(edge_attr, edge_attr.shape[0] // 2)

        # Generate learnable embeddings for node features
        x = self.embedding(x.squeeze(-1))
        # Embed edge features into the high-dimensional space (only needed
        # when there is at least one message-passing step).
        if self.message_passing_steps:
            edge_attr_u = self.MLP_E(edge_attr_u.float())
            edge_attr_d = self.MLP_E(edge_attr_d.float())
        for message_passing_step in self.message_passing_steps:
            x = message_passing_step(x, edge_index_u, edge_index_d, edge_attr_u, edge_attr_d)

        # [N, C] -> [1, C, N]: every node in the batch becomes one position of
        # a single Conv1d sequence; with kernel_size=1 nodes do not interact.
        x = x.transpose(0, 1).unsqueeze(0)
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = x.squeeze(0).transpose(0, 1)  # back to [N, 1024]

        # NOTE(review): assuming gmp is global max pooling, each channel's
        # gradient reaches only the arg-max node of each graph. Everything
        # upstream (message-passing weights included) therefore receives
        # sparse gradients.
        g = gmp(x, batch)
        # Final prediction network
        for hidden in (self.lin1, self.lin2, self.lin3, self.lin4, self.lin5):
            g = F.elu(hidden(g))
        return self.lin6(g)

Any help or insight into what might be going on here would be much appreciated; I feel like I’ve been smacking my head into a brick wall with this problem for several weeks now!