PyTorch Geometric: get data for individual graphs in a batch

I’m new to PyTorch Geometric. I process my data in batches: for each batch, I forward the data through several layers and end up with w_att, an attention weight matrix of dimension NxN, where N is the total number of nodes across all graphs in the batch. (E.g., if my batch has 2 graphs with 3 and 4 nodes, respectively, then N = 3 + 4 = 7.) The problem is that I now need the attention weight matrix for each graph individually, of size num_nodes x num_nodes, extracted from w_att. (In the same example, I need 2 separate attention weight matrices: the first should be 3x3 and the second 4x4.) What is a good way to obtain these per-graph matrices from the encompassing w_att?
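
To make the goal concrete, here is a minimal sketch of one possible approach, not necessarily the recommended one. It assumes a 1-D assignment vector mapping each node to its graph index is available (PyTorch Geometric batches expose this as the batch attribute) and that the nodes of each graph are stored contiguously and in graph order, so each per-graph matrix is a diagonal block of w_att. The helper name split_attention_per_graph is hypothetical, and the w_att values below are dummy data for illustration.

```python
import torch


def split_attention_per_graph(w_att, batch):
    """Return the diagonal blocks of an N x N matrix, one block per graph.

    Assumes `batch` is sorted so that nodes of the same graph are contiguous.
    """
    # Number of nodes in each graph, e.g. [3, 4] for the example in the question.
    counts = torch.bincount(batch).tolist()
    blocks, start = [], 0
    for n in counts:
        # Rows and columns start..start+n belong to the same graph.
        blocks.append(w_att[start:start + n, start:start + n])
        start += n
    return blocks


# Example from the question: 2 graphs with 3 and 4 nodes, so N = 7.
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])
w_att = torch.arange(49, dtype=torch.float32).view(7, 7)  # dummy attention values

per_graph = split_attention_per_graph(w_att, batch)
print([tuple(b.shape) for b in per_graph])  # [(3, 3), (4, 4)]
```

The off-diagonal blocks (attention between nodes of different graphs) are simply discarded here; whether that is appropriate depends on how w_att was computed.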