I am trying to implement an MPNN (pseudo-code below) on relatively large graphs (graphs w/ > 2000 nodes) to do per-node classification.
Because of the nature of my MPNN, I end up passing an entire graph through a series of networks. The output of these networks is a classification for each node in the graph, as mentioned, and each node's classification depends on some neighbourhood of surrounding nodes: node states are updated through NNs over several timesteps, and each node eventually gets a classification from its final state.
My concern comes from using autograd and computing the loss.
When I compute the loss, it is computed over all nodes in the graph, and then I run optimizer.step() to update my weights. The problem I see with this is that it is equivalent to passing in a batch of size > 2000, and from what I have read, batches that large seem bad for training.
Some pseudo code looks like this:
```
for each graph:
    for each timestep:
        for each node:
            compute new node state from neighbor node states using a neural network
    for each node:
        classify node based on its final node state using a second neural network
    loss = compute loss for all nodes
    loss.backward()
    optimizer.step()
```
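To make the question concrete, here is a minimal runnable sketch of the loop above in PyTorch. All names (`msg_net`, `cls_net`, the dense adjacency, the mean-over-neighbors aggregation) are placeholder assumptions, not my actual networks:

```python
import torch
import torch.nn as nn

# Toy sizes: N nodes, state dim D, C classes, T message-passing timesteps.
N, D, C, T = 6, 8, 3, 2
adj = torch.randint(0, 2, (N, N)).float()   # toy dense adjacency matrix

msg_net = nn.Linear(2 * D, D)   # updates a node state from (own, aggregated-neighbor) states
cls_net = nn.Linear(D, C)       # classifies a node from its final state

h = torch.randn(N, D)                   # initial node states
labels = torch.randint(0, C, (N,))      # toy per-node labels

for _ in range(T):
    # Mean over neighbor states (clamp avoids division by zero for isolated nodes).
    agg = adj @ h / adj.sum(1, keepdim=True).clamp(min=1)
    h = torch.relu(msg_net(torch.cat([h, agg], dim=1)))

logits = cls_net(h)                                  # one classification per node
loss = nn.functional.cross_entropy(logits, labels)   # averaged over all N nodes
loss.backward()                                      # gradients from every node's loss term
```

Here the single `loss.backward()` accumulates gradient contributions from every node, which is what makes it feel like one giant batch.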
Am I right in my understanding that this is a bad way to go about updating my weights? Is there a way in PyTorch to compute the loss on only a small subset of the graph's nodes and update the weights using only the gradients computed from that subset? Something like:
```
for each graph:
    for each timestep:
        for each node:
            compute new node state from neighbor node states using a neural network
    for each node:
        classify node based on its final node state using a second neural network
    # Is this valid?
    loss = compute loss for a subset of nodes
    loss.backward()
    optimizer.step()
```
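For what it's worth, indexing the outputs before computing the loss does seem expressible in PyTorch: the forward pass still runs over all nodes, but only the selected rows contribute loss terms. A minimal sketch (the `subset` indices and `cls_net` are made up for illustration):

```python
import torch
import torch.nn as nn

# Toy setup: full forward pass over all nodes, loss on a subset only.
N, D, C = 6, 8, 3
cls_net = nn.Linear(D, C)
h = torch.randn(N, D, requires_grad=True)   # stand-in for the final node states
labels = torch.randint(0, C, (N,))

subset = torch.tensor([0, 2, 5])            # the nodes we want to train on
logits = cls_net(h)                         # classify every node
loss = nn.functional.cross_entropy(logits[subset], labels[subset])
loss.backward()

# Gradient only flows from the selected nodes' loss terms, but it still flows
# backward through every computation that produced their states.
```

Autograd handles the bookkeeping automatically: nodes excluded from the loss contribute nothing to the gradient of the classifier, though the message-passing steps that fed the selected nodes' states still receive gradient.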
I still need all the nodes of the graph as input to compute the classification for any given node (due to the message passing), so I cannot simply remove them from the graph.