I’m trying to apply gradient checkpointing to a graph neural network, checkpointing each message-passing iteration, but the backward pass is not working at all. I have tried the fixes suggested in other forum threads, such as passing a dummy tensor and forcing requires_grad_(True), to no avail. The basic structure is:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckGNN(nn.Module):
    """
    Segment classification graph neural network model.
    Consists of an input network, an edge network, and a node network.
    """

    def __init__(self, in_channels=3, hidden_dim=8, n_graph_iters=3,
                 hidden_activation=torch.nn.Tanh, layer_norm=True):
        super(CheckGNN, self).__init__()
        self.n_graph_iters = n_graph_iters
        # Set up the input network
        self.input_network = mlp(...)
        # Set up the edge network
        self.edge_network = EdgeNetwork(...)
        # Set up the node network
        self.node_network = NodeNetwork(...)
        # Attempt at the dummy-input fix suggested in other threads
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)

    def custom_forward(self, inputs):
        # Apply the edge network
        e = self.edge_network((inputs[0], inputs[1]))
        e = torch.sigmoid(e)
        # Apply the node network
        x = self.node_network((inputs[0], e, inputs[1]))
        return x

    def forward(self, x, edge_index):
        """Apply the forward pass of the model"""
        input_x = x
        x = self.input_network(x)
        # Shortcut-connect the inputs onto the hidden representation
        x = torch.cat([x, input_x], dim=-1)
        # Loop over iterations of the edge and node networks
        for i in range(self.n_graph_iters):
            x_initial = x
            x = checkpoint(self.custom_forward, (x, edge_index, self.dummy_tensor))
            # Shortcut-connect the inputs onto the hidden representation
            x = torch.cat([x, input_x], dim=-1)
            x = x_initial + x
        e = self.edge_network((x, edge_index))
        return e
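For context, here is roughly how I run the model and compute the loss; the loss function and batch attribute names below are placeholders for my actual pipeline:

model = CheckGNN()
out = model(batch.x, batch.edge_index)  # final edge scores (logits)
loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch.y)
loss.backward()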
Printing loss.grad gives None, and the weight gradients for the NodeNetwork are empty. They do exist for the EdgeNetwork, but only because of the final, uncheckpointed call that produces the edge scores.
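For reference, this is how I’m inspecting the gradients after backward (the parameter names are just whatever the submodules define):

print(loss.grad)  # -> None
for name, p in model.node_network.named_parameters():
    print(name, p.grad)  # -> None for every NodeNetwork parameter
for name, p in model.edge_network.named_parameters():
    print(name, p.grad)  # -> populated, but only via the final uncheckpointed call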
The dummy input doesn’t change any of this. Am I missing something fundamental here?
P.S. I believe the details of the Edge and Node networks are irrelevant, but I can include them if needed.