How to get individual graph outputs when using batched graph data and a custom forward function

I have a question about how PyTorch Geometric handles predictions from the model’s forward function when the input is batched data. Here is my setup:

This is my batched graph data. The main thing to focus on for this example is the x_atm component.

Graph_DataBatch(atoms=[8000], edge_index_G=[2, 206578], edge_index_A=[2, 1249394], x_atm=[9600000, 1], x_atm_batch=[9600000], x_atm_ptr=[8001], x_bnd=[206578], x_bnd_batch=[206578], x_bnd_ptr=[8001], x_ang=[1249394], x_ang_batch=[1249394], x_ang_ptr=[8001], mask_dih_ang=[1249394], atm_amounts=[0], bnd_amounts=[0], ang_amounts=[0], y=[413156, 1])

This is fine as far as I can tell. Here is my model’s forward function:

    def forward(self, data):
        data = self.encoder(data)
        data = self.processor(data)
        return self.decoder(data)

The encoder and processor aren’t that important to my question, but the decoder is. Here is the decoder:

    import torch.nn as nn
    from torch_geometric.nn import MLP

    class PositiveScalarsDecoder(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim
            # One head per physical property: an MLP down to a single value,
            # then Softplus so every predicted scalar is positive.
            self.transform_atm = nn.Sequential(MLP([dim, dim, 1], act=nn.SiLU()), nn.Softplus())
            self.transform_bnd = nn.Sequential(MLP([dim, dim, 1], act=nn.SiLU()), nn.Softplus())
            self.transform_ang = nn.Sequential(MLP([dim, dim, 1], act=nn.SiLU()), nn.Softplus())

        def forward(self, data):
            # Each head only sees the embeddings of its own element type.
            atm_scalars = self.transform_atm(data.h_atm)
            bnd_scalars = self.transform_bnd(data.h_bnd)
            ang_scalars = self.transform_ang(data.h_ang)
            return [atm_scalars, bnd_scalars, ang_scalars]

Here, the final scalars represent physical properties of our system: the atoms, the bonds between them, and the bond angles between them. Setting up the decoder this way lets us use the final weights in each tensor group to rank features for each physical property.

Now, here is where I’m confused. The decoder, which receives a batch of graph data, outputs those final scalars in terms of the whole batch. Here is an example of the output from atm_scalars = self.transform_atm(data.h_atm):

    tensor([[0.6404],
            [0.6404],
            [0.6404],
            ...,
            [0.6404],
            [0.6404],
            [0.6404]], device='cuda:0', grad_fn=<SoftplusBackward0>)

The batched data has x_atm=[9600000, 1], and the atm_scalars tensor also has 9600000 rows. My understanding is that this is the correct behavior for batched graph data, since the batch is essentially one giant disconnected graph (is this true?).
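To check my understanding of the "giant disconnected graph" picture, here is a conceptual sketch in plain PyTorch with two made-up toy graphs. It is not PyG's actual collate code, only my assumption of what batching amounts to: node features are stacked along dim 0, the second graph's edge indices are shifted by the first graph's node count, and a batch vector plus a ptr vector record which rows belong to which graph.

```python
import torch

# Two made-up toy graphs: graph A has 3 nodes, graph B has 2 nodes.
x_a = torch.tensor([[1.0], [2.0], [3.0]])
edge_a = torch.tensor([[0, 1],
                       [1, 2]])  # edges 0->1 and 1->2
x_b = torch.tensor([[4.0], [5.0]])
edge_b = torch.tensor([[0],
                       [1]])     # edge 0->1

# Node features are stacked along dim 0, like x_atm=[9600000, 1].
x = torch.cat([x_a, x_b], dim=0)

# Graph B's edge indices are shifted by graph A's node count, so no edge
# connects nodes from different graphs: one big disconnected graph.
edge_index = torch.cat([edge_a, edge_b + x_a.size(0)], dim=1)

# batch says which graph each node belongs to; ptr holds each graph's first
# row plus the total (num_graphs + 1 entries, like x_atm_ptr=[8001]).
batch = torch.tensor([0] * x_a.size(0) + [1] * x_b.size(0))
ptr = torch.tensor([0, x_a.size(0), x_a.size(0) + x_b.size(0)])

print(x.shape)              # torch.Size([5, 1])
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
print(ptr.tolist())         # [0, 3, 5]

# Every edge stays inside a single graph.
print(bool((batch[edge_index[0]] == batch[edge_index[1]]).all()))  # True
```

If this picture is right, the x_atm_batch and x_atm_ptr attributes in my batch (presumably created through follow_batch) are exactly this kind of bookkeeping for the atoms.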

My problem is that this isn’t really what I want. I want the final predicted scalars grouped by the graph they came from, because I’m using those scalars to rank the graph features.

For example, in the case shown above each individual graph has x_atm=[1200, 1], and with a batch size of 8000 we get 8000 × 1200 = 9600000 rows. But this is just creating a single graph with 9600000 x_atm features, correct? What I want is 8000 separate graphs, each with an x_atm attribute of size 1200, and my final output grouped accordingly: 8000 separate predictions of size 1200 each.
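Since every graph in my data has exactly 1200 atoms, I think the atom output could simply be reshaped, assuming the batch keeps each graph's rows contiguous and in graph order (which is what the x_atm_ptr offsets suggest). A minimal sketch with small stand-in sizes instead of 8000 × 1200:

```python
import torch

# Small stand-in sizes; the real batch is 8000 graphs x 1200 atoms.
num_graphs, atoms_per_graph = 4, 3

# Made-up flat output laid out graph by graph: rows 0-2 belong to graph 0,
# rows 3-5 to graph 1, and so on (shape [12, 1], like atm_scalars).
atm_scalars = torch.arange(num_graphs * atoms_per_graph, dtype=torch.float).unsqueeze(-1)

# With equal-sized graphs, a view regroups the rows without copying:
# shape [num_graphs, atoms_per_graph, 1].
per_graph = atm_scalars.view(num_graphs, atoms_per_graph, -1)

print(per_graph.shape)                    # torch.Size([4, 3, 1])
print(per_graph[1].squeeze(-1).tolist())  # [3.0, 4.0, 5.0]
```

For the real tensor this would be atm_scalars.view(8000, 1200, -1). It only works because every graph has the same atom count.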

Is there a simple way to do this without adding a lot of bookkeeping of feature lengths and manually splitting the predicted values into their corresponding graphs after the forward function is called?
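One possibility I'm considering, sketched under the assumption that each decoder output keeps the same row order as its x_* input: the *_ptr vectors in the batch already hold the per-graph offsets, so a small helper (split_by_ptr is a hypothetical name of my own) could cut any flat output into per-graph pieces. This also covers graphs of different sizes, which must be the case for bonds and angles, since 206578 and 1249394 are not multiples of 8000.

```python
import torch


def split_by_ptr(values: torch.Tensor, ptr: torch.Tensor) -> tuple:
    """Split a flat tensor into one chunk per graph (hypothetical helper).

    ptr[i] is the first row of graph i and ptr[-1] is the total number of
    rows, which is the layout of the *_ptr attributes in a PyG batch.
    """
    # tensor_split wants only the interior cut points, not 0 or the final offset.
    cuts = ptr[1:-1].tolist()
    return torch.tensor_split(values, cuts, dim=0)


# Made-up toy batch: 3 graphs with 2, 0 and 3 bonds respectively.
bnd_scalars = torch.tensor([[0.1], [0.2], [0.3], [0.4], [0.5]])
x_bnd_ptr = torch.tensor([0, 2, 2, 5])

chunks = split_by_ptr(bnd_scalars, x_bnd_ptr)
print([c.shape[0] for c in chunks])  # [2, 0, 3]

# With the real batch this would be applied to each decoder output, e.g.:
# atm_scalars, bnd_scalars, ang_scalars = model(data)
# atm_per_graph = split_by_ptr(atm_scalars, data.x_atm_ptr)
# bnd_per_graph = split_by_ptr(bnd_scalars, data.x_bnd_ptr)
# ang_per_graph = split_by_ptr(ang_scalars, data.x_ang_ptr)
```

Each chunk is a view into the flat tensor, so nothing is copied. I believe PyG also provides torch_geometric.utils.unbatch(src, batch), which should give the same list of chunks from a batch vector such as data.x_bnd_batch, but I haven't verified which version introduced it.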

Hopefully my question makes sense, but if not I’m happy to try and clarify anything.