Gradient Checkpointing not backpropping in GNN

I’m trying to apply gradient checkpointing to a graph neural network, with each iteration of message passing checkpointed, but the backward pass is not working at all. I have tried the ideas from other forum questions, such as dummy variables and forcing requires_grad_(True), to no avail. The basic structure is:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckGNN(nn.Module):
    """Segment classification graph neural network model.
    Consists of an input network, an edge network, and a node network.
    """
    def __init__(self, in_channels=3, hidden_dim=8, n_graph_iters=3,
                 hidden_activation=torch.nn.Tanh, layer_norm=True):
        super(CheckGNN, self).__init__()
        self.n_graph_iters = n_graph_iters
        # Setup the input network
        self.input_network = mlp(...)
        # Setup the edge network
        self.edge_network = EdgeNetwork(...)
        # Setup the node layers
        self.node_network = NodeNetwork(...)
        # Attempt at dummy input fix
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
    def custom_forward(self, inputs):
        # Apply edge network
        e = self.edge_network((inputs[0], inputs[1]))
        e = torch.sigmoid(e)
        # Apply node network
        x = self.node_network((inputs[0], e, inputs[1]))
        return x
    def forward(self, x, edge_index):
        """Apply forward pass of the model"""
        input_x = x
        x = self.input_network(x)
        # Shortcut connect the inputs onto the hidden representation
        x = torch.cat([x, input_x], dim=-1)
        # Loop over iterations of edge and node networks
        for i in range(self.n_graph_iters):
            x_initial = x
            x = checkpoint(self.custom_forward, (x, edge_index, self.dummy_tensor))
            # Shortcut connect the inputs onto the hidden representation
            x = torch.cat([x, input_x], dim=-1)
            x = x_initial + x
        e = self.edge_network((x, edge_index))
        return e
```

Printing loss.grad gives None, and the weight gradients of the NodeNetwork are None. Gradients do exist for the EdgeNetwork, but only because of the final, uncheckpointed call that outputs the edge scores. The dummy input doesn’t change this. Am I missing something fundamental here?
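One likely cause, assuming `checkpoint` is `torch.utils.checkpoint.checkpoint` in its reentrant mode (which was the default): the call passes a single tuple, and the reentrant implementation only tracks tensors that are passed directly as positional arguments. Tensors nested inside the tuple, including the dummy tensor, are invisible to autograd, so the checkpointed output comes back detached and nothing inside `custom_forward` receives gradients. The residual `x_initial + x` still requires grad, which would explain why `backward()` runs without an error. The self-contained sketch below reproduces this with a hypothetical `nn.Linear` block standing in for the edge and node networks (assumes PyTorch 1.11 or later for the `use_reentrant` argument).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = nn.Linear(4, 4)  # hypothetical stand-in for the edge/node networks

def block_tuple(inputs):
    # Receives ONE tuple argument, like custom_forward in the post
    x, scale = inputs
    return torch.tanh(layer(x)) * scale

def block_args(x, scale):
    # Receives the tensors as separate positional arguments
    return torch.tanh(layer(x)) * scale

x = torch.randn(5, 4, requires_grad=True)
scale = torch.ones(1)

# Reentrant mode, tuple passed as a single argument: autograd sees no tensor
# inputs, so the output is cut off from the graph (PyTorch also warns here).
out_tuple = checkpoint(block_tuple, (x, scale), use_reentrant=True)
print(out_tuple.requires_grad)  # False

# Reentrant mode, tensors unpacked: gradients reach the layer's weights.
out_args = checkpoint(block_args, x, scale, use_reentrant=True)
out_args.sum().backward()
print(layer.weight.grad is not None)  # True

# Non-reentrant mode records the graph normally, so the tuple form works too.
layer.weight.grad = None
layer.bias.grad = None
out_nr = checkpoint(block_tuple, (x, scale), use_reentrant=False)
out_nr.sum().backward()
print(layer.weight.grad is not None)  # True
```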

P.S. I believe the details of the Edge and Node networks are irrelevant, but I can include them if needed.

Hi Daniel, I am facing the same problem. Did you find a solution?
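A sketch of how the message-passing loop could look with the tensors passed as separate arguments and `use_reentrant=False`, assuming PyTorch 1.11 or later. `ToyEdgeNetwork`, `ToyNodeNetwork`, and `ToyCheckGNN` are hypothetical stand-ins, since the real `EdgeNetwork` and `NodeNetwork` are not shown; only the checkpoint call pattern is the point here.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyEdgeNetwork(nn.Module):
    """Hypothetical stand-in: scores each edge from its two endpoint features."""
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Linear(2 * dim, 1)

    def forward(self, x, edge_index):
        src, dst = edge_index
        return self.mlp(torch.cat([x[src], x[dst]], dim=-1)).squeeze(-1)


class ToyNodeNetwork(nn.Module):
    """Hypothetical stand-in: sums edge-weighted messages into each node."""
    def __init__(self, dim, out_dim):
        super().__init__()
        self.mlp = nn.Linear(2 * dim, out_dim)

    def forward(self, x, e, edge_index):
        src, dst = edge_index
        # Out-of-place index_add: accumulate e * x[src] into row dst
        agg = torch.zeros_like(x).index_add(0, dst, e.unsqueeze(-1) * x[src])
        return torch.tanh(self.mlp(torch.cat([agg, x], dim=-1)))


class ToyCheckGNN(nn.Module):
    def __init__(self, in_channels=3, hidden_dim=8, n_graph_iters=3):
        super().__init__()
        self.n_graph_iters = n_graph_iters
        self.input_network = nn.Sequential(nn.Linear(in_channels, hidden_dim), nn.Tanh())
        dim = hidden_dim + in_channels
        self.edge_network = ToyEdgeNetwork(dim)
        self.node_network = ToyNodeNetwork(dim, hidden_dim)

    def step(self, x, edge_index):
        # Tensors arrive as separate positional arguments, not one tuple
        e = torch.sigmoid(self.edge_network(x, edge_index))
        return self.node_network(x, e, edge_index)

    def forward(self, x, edge_index):
        input_x = x
        x = torch.cat([self.input_network(x), input_x], dim=-1)
        for _ in range(self.n_graph_iters):
            x_initial = x
            x = checkpoint(self.step, x, edge_index, use_reentrant=False)
            x = torch.cat([x, input_x], dim=-1)
            x = x_initial + x
        return self.edge_network(x, edge_index)


torch.manual_seed(0)
model = ToyCheckGNN()
x = torch.randn(6, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
loss = model(x, edge_index).sigmoid().mean()
loss.backward()
print(model.node_network.mlp.weight.grad is not None)  # True
```

Either change on its own should be enough in this setup: unpacking the arguments lets reentrant checkpointing see `x`, which already requires grad through the input network, so no dummy tensor is needed; `use_reentrant=False` also tolerates tuple arguments. A dummy tensor only matters in reentrant mode when none of the tensor arguments require grad, and even then it has to be passed as its own positional argument rather than inside a tuple.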