How to pass a batch of graphs to GATv2Conv?

I have a batched input for GATv2Conv whose node feature matrix has shape [batch_sz, num_nodes, node_feature_dim], but GATv2Conv only accepts 2-D node features. Searching the internet, I found the following solution, which is not the one I want:

from torch_geometric.data import Data, Batch

# Build one Data object per graph; x[idx] has shape [num_nodes, node_feature_dim].
# (`i` comes from the enclosing code, which is not shown here.)
data_list = [Data(x=x[idx],
                  edge_index=self.edge_indices,
                  edge_attr=adj_mats[i][idx])
             for idx in range(self.batch_sz)]
batch = Batch.from_data_list(data_list)
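For reference, here is a minimal torch-only sketch of what batching a dense [batch_sz, num_nodes, node_feature_dim] tensor amounts to, assuming every graph shares the same edge_index (as self.edge_indices suggests). It mirrors what Batch.from_data_list does: node features are stacked, each graph's edge indices are shifted by the number of nodes in the graphs before it, and a batch vector records which graph each node came from. The helper name dense_to_disjoint_batch is hypothetical.

```python
import torch


def dense_to_disjoint_batch(x: torch.Tensor, edge_index: torch.Tensor):
    """Hypothetical helper: turn a dense batch into one disjoint graph.

    x:          [batch_sz, num_nodes, node_feature_dim]
    edge_index: [2, num_edges], assumed to be shared by every graph
    """
    batch_sz, num_nodes, feat_dim = x.shape
    # Stack node features of all graphs: [batch_sz * num_nodes, feat_dim]
    x_flat = x.reshape(batch_sz * num_nodes, feat_dim)
    # Shift node indices of graph g by g * num_nodes so that no edge
    # connects nodes that belong to different graphs.
    offsets = torch.arange(batch_sz, device=x.device) * num_nodes  # [batch_sz]
    ei = edge_index.unsqueeze(1) + offsets.view(1, -1, 1)  # [2, batch_sz, num_edges]
    ei = ei.reshape(2, -1)  # [2, batch_sz * num_edges], ordered graph by graph
    # batch[n] = index of the graph that node n belongs to
    batch = torch.arange(batch_sz, device=x.device).repeat_interleave(num_nodes)
    return x_flat, ei, batch


# Usage: 3 graphs, 4 nodes each, 8 features per node, a shared 3-edge path.
x = torch.randn(3, 4, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
x_flat, ei, batch = dense_to_disjoint_batch(x, edge_index)
print(x_flat.shape)    # torch.Size([12, 8])
print(ei.tolist())     # [[0, 1, 2, 4, 5, 6, 8, 9, 10], [1, 2, 3, 5, 6, 7, 9, 10, 11]]
print(batch.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```

Because GATv2Conv computes attention only along the edges in edge_index (plus per-node self-loops), the shifted edges keep each graph's message passing separate. If there are per-graph edge features of shape [batch_sz, num_edges, edge_dim], reshaping them with edge_attr.reshape(-1, edge_dim) lines up with ei, since both are ordered graph by graph.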

But with the above solution, the distinction between graphs gets lost, because:
batch.x.shape is [batch_sz * num_nodes, node_feature_dim]

It simply puts all nodes of all graphs into one single graph. Now computations are shared between different graphs, which is strictly undesirable. Also, when applying a graph pooling layer, I don't know which nodes belong to which graph.
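For context on the pooling concern: a PyG Batch stores graph membership in batch.batch, a vector holding the graph index of every node, and pooling functions such as torch_geometric.nn.global_mean_pool(x, batch.batch) use it to reduce each graph separately. The torch-only sketch below illustrates that idea; the helper name mean_pool_per_graph is hypothetical and is not a PyG function.

```python
import torch


def mean_pool_per_graph(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Average node features per graph using a batch assignment vector.

    x:     [total_nodes, feat_dim] node features of all graphs, stacked
    batch: [total_nodes] graph index of every node (long tensor)
    """
    feat_dim = x.size(1)
    sums = torch.zeros(num_graphs, feat_dim, dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)  # add each node's features to its graph's row
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)  # nodes per graph
    return sums / counts.unsqueeze(1).to(x.dtype)


# Usage: graph 0 has 2 nodes, graph 1 has 3 nodes.
x = torch.tensor([[1.0], [3.0], [10.0], [20.0], [30.0]])
batch = torch.tensor([0, 0, 1, 1, 1])
print(mean_pool_per_graph(x, batch, num_graphs=2).tolist())  # [[2.0], [20.0]]
```

Applied to the flattened output of a GATv2Conv layer, this returns one row per graph, which is the role global_mean_pool(out, batch.batch) plays in PyG.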

Please suggest a fix for this issue.

Thanks in advance