I’m trying to apply gradient checkpointing to a graph neural network, with each iteration of message passing being checkpointed, but the backward pass does not work at all. I have tried the ideas from other forum questions, such as dummy variables and forcing `requires_grad_(True)`, to no avail. The basic structure is:
```python
class CheckGNN(nn.Module):
    """
    Segment classification graph neural network model.

    Consists of an input network, an edge network, and a node network.
    """

    def __init__(self, in_channels=3, hidden_dim=8, n_graph_iters=3,
                 hidden_activation=torch.nn.Tanh, layer_norm=True):
        super(CheckGNN, self).__init__()
        self.n_graph_iters = n_graph_iters
        # Setup the input network
        self.input_network = mlp(...)
        # Setup the edge network
        self.edge_network = EdgeNetwork(...)
        # Setup the node layers
        self.node_network = NodeNetwork(...)
        # Attempt at dummy input fix
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)

    def custom_forward(self, inputs):
        # Apply edge network
        e = self.edge_network((inputs, inputs))
        e = torch.sigmoid(e)
        # Apply node network
        x = self.node_network((inputs, e, inputs))
        return x

    def forward(self, x, edge_index):
        """Apply forward pass of the model"""
        input_x = x
        x = self.input_network(x)
        # Shortcut connect the inputs onto the hidden representation
        x = torch.cat([x, input_x], dim=-1)
        # Loop over iterations of edge and node networks
        for i in range(self.n_graph_iters):
            x_initial = x
            x = checkpoint(self.custom_forward, (x, edge_index, self.dummy_tensor))
            # Shortcut connect the inputs onto the hidden representation
            x = torch.cat([x, input_x], dim=-1)
            x = x_initial + x
        e = self.edge_network((x, edge_index))
        return e
```
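For context, the dummy-tensor workaround I took from those other threads, as I understand it, looks like this in isolation (a toy `nn.Linear` stand-in, not my actual networks):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a real sub-network (hypothetical, not my NodeNetwork).
lin = nn.Linear(4, 4)

def run_fn(x, dummy):
    # dummy is unused in the computation; it only exists so that at least
    # one input to checkpoint() has requires_grad=True under the
    # reentrant variant of checkpointing.
    return lin(x)

x = torch.randn(2, 4)                      # does not require grad
dummy = torch.ones(1, requires_grad=True)  # the workaround tensor
out = checkpoint(run_fn, x, dummy, use_reentrant=True)
out.sum().backward()
print(lin.weight.grad is not None)  # True: gradients reach the weights
```

Note that here the tensors are passed to `checkpoint` as separate positional arguments, which is what its signature expects.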
After the backward pass, the gradients of the NodeNetwork weights are all `None`. Gradients exist for the EdgeNetwork, but only because of the final, uncheckpointed call when it outputs the edge scores. The dummy input doesn’t change this. Am I missing something fundamental here?
P.S. I believe the details of the Edge and Node networks are irrelevant, but I can include them if needed.