Is it viable to update weights using the gradient for a subset of the output?

I am trying to implement an MPNN (pseudo-code below) on relatively large graphs (graphs w/ > 2000 nodes) to do per-node classification.

Because of the nature of my MPNN, I end up passing an entire graph through a series of networks, and the output of these networks is a classification for each node in the graph, as mentioned. Each node's classification depends on some set of surrounding nodes: I update node states through neural networks over the timesteps and eventually classify each node from its final state.

My concern comes from using autograd and computing the loss.

When I compute the loss, it is computed for all nodes in the graph, and then I run optimizer.step() to update my weights. The problem I see with this is that it is equivalent to passing in a batch of size > 2000, and from what I have read that seems like a bad idea.

Some pseudo-code looks like this:

for each graph:
    for each timestep:
        for each node:
            compute new node state based on its neighbors' node states using a neural network
    for each node:
        classify node based on its final node state using a neural network (a different NN from the first)

    loss = compute loss for all nodes
    loss.backward()
    optimizer.step()

Am I right in my understanding that this is a bad way to go about updating my weights? Is there a way in PyTorch to compute the loss on only a small subset of the graph and update the weights using only the gradients computed on that subset? Something like:

for each graph:
    for each timestep:
        for each node:
            compute new node state based on its neighbors' node states using a neural network
    for each node:
        classify node based on its final node state using a neural network (a different NN from the first)

    # Is this valid?
    loss = compute loss for subset of nodes
    loss.backward()
    optimizer.step()

I still need all the nodes of the graph as input to compute the classification for any given node (due to the message passing), so I cannot simply remove them from the graph.
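To make that concrete, here is roughly what I have in mind in PyTorch (model, node_features, edge_index, labels, and subset_idx are placeholder names for my setup): the forward pass still runs over the whole graph, and only the loss is restricted to the subset.

import torch.nn.functional as F

# Placeholder names: `model` runs the full message-passing forward pass and
# returns per-node logits of shape [num_nodes, num_classes]; `labels` holds
# one class id per node; `subset_idx` is a LongTensor of the node ids whose
# loss should drive this update.
logits = model(node_features, edge_index)      # forward pass still uses every node
loss = F.cross_entropy(logits[subset_idx],     # loss restricted to the chosen nodes
                       labels[subset_idx])

optimizer.zero_grad()
loss.backward()   # gradients flow back only from the selected nodes' outputs
optimizer.step()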

Here are a couple of pointers I can think of, just to start with:

  1. You can use torch.autograd.grad() to calculate the gradients for a loss computed on a subset of nodes (see the sketch after this list); or

  2. If you only want the update to be driven by certain nodes, freeze the parameters that should not change (set requires_grad to False), do a backprop, and update the remaining weights.
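A rough sketch of option 1, assuming subset_loss is a loss already computed only over the chosen nodes' outputs, and model / optimizer hold your networks and optimizer (all placeholder names):

import torch

# Placeholder setup: `subset_loss` is the loss over the selected nodes only,
# and `model` contains both the message-passing and classification networks.
params = [p for p in model.parameters() if p.requires_grad]

# torch.autograd.grad returns the gradients without populating .grad,
# so you can inspect, scale, or combine them before applying an update.
# (If some parameters are not touched by this loss, pass allow_unused=True.)
grads = torch.autograd.grad(subset_loss, params)

# Write the gradients back into .grad and take a normal optimizer step.
for p, g in zip(params, grads):
    p.grad = g
optimizer.step()
optimizer.zero_grad()

Note that simply calling subset_loss.backward() followed by optimizer.step() gives the same update; torch.autograd.grad() is mainly useful when you want to look at or manipulate the gradients yourself first.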