Graph-Level Embedding for a Heterogeneous Graph (global_mean_pool)

I’m working with PyTorch Geometric and heterogeneous graphs.
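How can I aggregate the node embeddings across all node types to obtain a single embedding for the whole graph? Currently the model below returns one pooled embedding per node type (see the output at the end).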

class GNN(torch.nn.Module):
    """Two-layer GraphSAGE encoder that mean-pools node embeddings.

    The module is written for a homogeneous graph and is meant to be
    converted with ``to_hetero``. After that conversion, ``forward`` returns
    a dict with one pooled ``[1, out_channels]`` embedding per node type.
    Merging those per-type embeddings into one graph embedding therefore has
    to happen outside this module.

    Args:
        hidden_channels: Output size of the first SAGEConv layer.
        out_channels: Output size of the second SAGEConv layer, which is
            also the size of the pooled embedding.
    """

    def __init__(self, hidden_channels, out_channels):
        # Module.__init__ must run before any submodule is assigned; without
        # it, ``self.conv1 = ...`` raises AttributeError.
        super().__init__()
        # ``(-1, -1)`` leaves the (source, target) input sizes to be inferred
        # lazily on the first forward pass. After to_hetero these sizes can
        # differ per node type.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Encode nodes with two SAGEConv layers and mean-pool them.

        Args:
            x: Node feature matrix (a dict keyed by node type after
                ``to_hetero``).
            edge_index: Graph connectivity (a dict keyed by edge type after
                ``to_hetero``).

        Returns:
            The mean of all node embeddings as a single row. After
            ``to_hetero``, this is one such row per node type.
        """
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        # batch=None treats every node as part of one graph, so all nodes
        # are pooled into a single row.
        return global_mean_pool(x, batch=None)

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

# After to_hetero, the model returns {node_type: [1, out_channels]}: one
# mean-pooled embedding per node type. To get a single graph embedding:
# stack the per-type embeddings into [num_node_types, 1, out_channels],
# then average over the node-type axis.
# Every node type counts equally here, regardless of its node count.
# Weighting each entry by that type's node count (x.size(0)) would instead
# reproduce the mean over all nodes.
out_dict = model(data.x_dict, data.edge_index_dict)
graph_embedding = torch.stack(list(out_dict.values()), dim=0).mean(dim=0)
# graph_embedding has shape [1, out_channels].


{'node_type_1': tensor([[-0.0687, -0.2087]], grad_fn=<MeanBackward1>),
 'node_type_2': tensor([[ 0.0210, -0.0251]], grad_fn=<MeanBackward1>),
 'node_type_3': tensor([[ 0.1923, -0.0379]], grad_fn=<MeanBackward1>),
 'node_type_4': tensor([[-0.0792,  0.1417]], grad_fn=<MeanBackward1>)}