Can't split HeteroData with RandomNodeSplit

import networkx as nx
import pandas as pd
from collections import defaultdict
import torch
from import Data, HeteroData
import gensim.downloader as api
from gensim.models import Word2Vec
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch_geometric.loader import NeighborLoader
import torch_geometric.transforms as T
from torch_geometric.transforms import RandomNodeSplit

Load pre-trained Word2Vec model

wv_model = api.load(“glove-wiki-gigaword-50”)

def encode_text(text):
“”“Encode text using Word2Vec or GloVe”“”
words = text.split()
embeddings =
for word in words:
embedding = wv_model.get_vector(word)
except KeyError:
# If the word is not in the vocabulary, skip it
if len(embeddings) == 0:
# If no words are found, return a random vector
return torch.randn(wv_model.vector_size)
return torch.mean(torch.tensor(embeddings), dim=0)

scaler = StandardScaler()
encoder = OneHotEncoder(sparse=False)

def preprocess_data(data):
“”“Preprocess the data by handling missing values, splitting categories, and encoding features.”“”
data = data.fillna({‘text’: ‘’, ‘categories’: ‘’})
data[‘categories’] = data[‘categories’].str.split(', ')
# Scale numerical features
data = scale_numerical_features(data, [‘longitude’, ‘latitude’, ‘review_count_x’, ‘review_count’])
# Encode categorical features
data = encode_categorical_features(data, [‘city’])
return data

def scale_numerical_features(data, features):
“”“Scale numerical features using StandardScaler”“”
scaled_data = data.copy()
scaled_data[features] = scaler.fit_transform(scaled_data[features])
return scaled_data

def encode_categorical_features(data, features):
“”“Encode categorical features using OneHotEncoder”“”
encoded_data = data.copy()
encoded_data[features] = encoder.fit_transform(encoded_data[features])
return encoded_data

def create_bipartite_graphs(data):
“”“Create user-restaurant and user-category bipartite graphs.”“”
user_restaurant_graph = nx.Graph()
user_category_graph = nx.Graph()

# Unique nodes
unique_users = data['user_id'].unique()
unique_restaurants = data['business_id'].unique()
unique_categories = set(category for categories in data['categories'] for category in categories)

# Add nodes to user-restaurant graph
user_restaurant_graph.add_nodes_from(unique_users, bipartite='user')
user_restaurant_graph.add_nodes_from(unique_restaurants, bipartite='restaurant')

# Add nodes to user-category graph
user_category_graph.add_nodes_from(unique_users, bipartite='user')
user_category_graph.add_nodes_from(unique_categories, bipartite='category')

# Add edges to user-restaurant graph
user_restaurant_edges = [(row['user_id'], row['business_id']) for _, row in data.iterrows()]

# Add 'stars_x' as a weighted edge attribute to user-restaurant graph
star_x_edges = [(row['user_id'], row['business_id'], {'star_x': row['stars_x']}) for _, row in data.iterrows()]

# Add edges to user-category graph
user_category_edges = [(row['user_id'], category) for _, row in data.iterrows() for category in row['categories']]

# Add 'stars_y' as a weighted edge attribute to user-category graph
star_y_edges = [(row['user_id'], category, {'star_y': row['stars_y']}) for _, row in data.iterrows() for category in row['categories']]

return user_restaurant_graph, user_category_graph

def add_node_features(user_restaurant_graph, user_category_graph, data):
“”“Add node features to the bipartite graphs.”“”
# Add ‘review_count’ as a node feature
user_review_count = dict(zip(data[‘user_id’], data[‘review_count_x’]))
restaurant_review_count = dict(zip(data[‘business_id’], data[‘review_count’]))
nx.set_node_attributes(user_restaurant_graph, user_review_count, name=‘review_count_user’)
nx.set_node_attributes(user_restaurant_graph, restaurant_review_count, name=‘review_count_restaurant’)

# Add 'city' as a node feature
restaurant_city = dict(zip(data['business_id'], data['city']))
nx.set_node_attributes(user_restaurant_graph, restaurant_city, name='city_restaurant')

# Add 'longitude' and 'latitude' as node features for restaurants
restaurant_longitude = dict(zip(data['business_id'], data['longitude']))
restaurant_latitude = dict(zip(data['business_id'], data['latitude']))
nx.set_node_attributes(user_restaurant_graph, restaurant_longitude, name='longitude_restaurant')
nx.set_node_attributes(user_restaurant_graph, restaurant_latitude, name='latitude_restaurant')

# Add categories as a node feature for users
user_categories = defaultdict(list)
for _, row in data.iterrows():
nx.set_node_attributes(user_category_graph, {user: list(set(categories)) for user, categories in user_categories.items()}, name='categories_user')

def add_edge_features(user_restaurant_graph, user_category_graph, data):
“”“Add edge features to the bipartite graphs.”“”
# Add ‘text’ as an edge feature
user_restaurant_text = {(row[‘user_id’], row[‘business_id’]): row[‘text’] for _, row in data.iterrows()}
user_category_text = {(row[‘user_id’], category): row[‘text’] for _, row in data.iterrows() for category in row[‘categories’]}
nx.set_edge_attributes(user_restaurant_graph, user_restaurant_text, name=‘review_text’)
nx.set_edge_attributes(user_category_graph, user_category_text, name=‘review_text’)

# Add 'name' and 'categories' as edge features for user-restaurant graph
user_restaurant_name = {(row['user_id'], row['business_id']): row['name'] for _, row in data.iterrows()}
user_restaurant_categories = {(row['user_id'], row['business_id']): row['categories'] for _, row in data.iterrows()}
nx.set_edge_attributes(user_restaurant_graph, user_restaurant_name, name='restaurant_name')
nx.set_edge_attributes(user_restaurant_graph, user_restaurant_categories, name='restaurant_categories')
# Add 'star_x' as an edge feature with weights
star_x_edges = [(row['user_id'], row['business_id'], {'star_x': row['stars_x']}) for _, row in data.iterrows()]
# Add 'stars_y' as a weighted edge attribute to user-category graph
star_y_edges = [(row['user_id'], category, {'star_y': row['stars_y']}) for _, row in data.iterrows() for category in row['categories']]

def encode_node_features(node_data):
“”“Encode node features using node attributes and review text.”“”
#print(f"Node data: {node_data}") # Add this line
node_attributes =

# Encode categorical attributes
city = node_data.get('city', '')
if city:
    city_encoding = encoder.transform([[city]]).toarray().flatten()
    node_attributes.extend([0] * len(encoder.categories_[0]))

# Encode numerical attributes
review_count = node_data.get('review_count', 0)

# Encode boolean attributes
is_open = node_data.get('is_open', 0)

# Encode review text
review_text = encode_text(node_data.get('review_text', ''))

return torch.tensor(node_attributes)

def create_pyg_data(graph, node_type1, node_type2, node_type3=None):
“”“Create a PyG Data or HeteroData object from a bipartite graph.”“”
data = HeteroData()

# Add node features for node_type1
node_type1_features = []
for _, node_data in graph.nodes(data=True):
    if node_data.get('bipartite') == node_type1:
data[node_type1].x = torch.stack(node_type1_features)

# Add node features for node_type2
node_type2_features = []
for _, node_data in graph.nodes(data=True):
    if node_data.get('bipartite') == node_type2:
data[node_type2].x = torch.stack(node_type2_features)

# Add node features for node_type3 (if provided)
if node_type3 is not None:
    node_type3_features = []
    for _, node_data in graph.nodes(data=True):
        if node_data.get('bipartite') == node_type3:
    data[node_type3].x = torch.stack(node_type3_features)

# Add edge indices
node_type_mapping = {'user': 0, 'restaurant': 1, 'category': 2}
edges = [(node_type_mapping[graph.nodes[u]['bipartite']], node_type_mapping[graph.nodes[v]['bipartite']]) for u, v in graph.edges()]
row, col = zip(*edges)
data[(node_type1, 'to', node_type2)].edge_index = torch.tensor([row, col], dtype=torch.long)
if node_type3 is not None:
    data[(node_type1, 'to', node_type3)].edge_index = torch.tensor([row, col], dtype=torch.long)

# Add edge features
edge_features = [encode_text(graph.edges[u, v]['review_text']) for u, v in graph.edges()]
data[(node_type1, 'to', node_type2)].edge_attr = torch.stack(edge_features)
if node_type3 is not None:
    data[(node_type1, 'to', node_type3)].edge_attr = torch.stack(edge_features)

# Add 'star_x' as an edge feature for user-restaurant graph
user_restaurant_star_x = [graph.edges[u, v].get('star_x', 0) for u, v in graph.edges() if graph.nodes[u]['bipartite'] == node_type1 and graph.nodes[v]['bipartite'] == node_type2]
data[(node_type1, 'to', node_type2)].edge_weight = torch.tensor(user_restaurant_star_x)

# Add 'star_y' as an edge feature for user-category graph
if node_type3 is not None:
    user_category_star_y = [graph.edges[u, v].get('star_y', 0) for u, v in graph.edges() if graph.nodes[u]['bipartite'] == node_type1 and graph.nodes[v]['bipartite'] == node_type3]
    data[(node_type1, 'to', node_type3)].edge_weight = torch.tensor(user_category_star_y)

return data

def create_combined_data(user_restaurant_graph, user_category_graph, preprocessed_data):
“”“Create the combined HeteroData object.”“”
# Add node features
add_node_features(user_restaurant_graph, user_category_graph, preprocessed_data)

# Add edge features
add_edge_features(user_restaurant_graph, user_category_graph, preprocessed_data)

# Create PyG Data objects
user_restaurant_data = create_pyg_data(user_restaurant_graph, 'user', 'restaurant')
user_category_data = create_pyg_data(user_category_graph, 'user', 'category')

# Create PyG homogeneous Data object (optional)
combined_graph = nx.compose(user_restaurant_graph, user_category_graph)
combined_data = create_pyg_data(combined_graph, 'user', 'restaurant', 'category')
# Split the combined_data into train, validation, and test sets based on 'user' node type
combined_data = RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0.1)(combined_data)
combined_data = T.ToUndirected()(combined_data)
print(f'data: {combined_data}')

return combined_data, combined_graph

def main():
# Load the dataset
data = pd.read_csv(“/kaggle/working/subset_review_merged1.csv”)

# Preprocess the data
preprocessed_data = preprocess_data(data)
# Create bipartite graphs
user_restaurant_graph, user_category_graph = create_bipartite_graphs(preprocessed_data)

# Create the combined HeteroData object and the combined graph
combined_data, combined_graph = create_combined_data(user_restaurant_graph, user_category_graph, preprocessed_data)

return combined_graph, combined_data

# Define the is_connected function
def is_connected(graph):
    """Check if a graph is connected using BFS."""
    # Initialize visited set and queue for BFS
    visited = set()
    queue = []

    # Choose an arbitrary starting node
    start_node = next(iter(graph.nodes()))

    # Start BFS from the chosen node

    while queue:
        # Dequeue a node from the queue
        node = queue.pop(0)

        # Visit neighbors of the current node
        for neighbor in graph.neighbors(node):
            if neighbor not in visited:

    # If all nodes are visited, the graph is connected
    return len(visited) == len(graph.nodes())

# Check if the combined graph is connected
combined_connected = is_connected(combined_graph)
print("Is the combined graph connected?", combined_connected)

# Check if the combined graph is bipartite
combined_bipartite = nx.is_bipartite(combined_graph)

# Visualization
plt.figure(figsize=(10, 6))
nx.draw(combined_graph, with_labels=True)
plt.title(f'Combined Bipartite Graph\nBipartite: {combined_bipartite}')

return combined_graph, combined_data

if name == ‘main’:
combined_graph, combined_data = main()

# Access edge attributes for ('user', 'to', 'restaurant') edge type
user_restaurant_edge_index = combined_data[('user', 'to', 'restaurant')].edge_index
user_restaurant_edge_attr = combined_data[('user', 'to', 'restaurant')].edge_attr
user_restaurant_edge_weight = combined_data[('user', 'to', 'restaurant')].edge_weight

# Access edge attributes for ('user', 'to', 'category') edge type
user_category_edge_index = combined_data[('user', 'to', 'category')].edge_index
user_category_edge_attr = combined_data[('user', 'to', 'category')].edge_attr
user_category_edge_weight = combined_data[('user', 'to', 'category')].edge_weight
# Print the dimensions of edge attributes for each edge type
print(f"Shape of data['user'].x: {combined_data['user'].x.shape}")

print("User to Restaurant Edge Index shape:", user_restaurant_edge_index.shape)
print("User to Restaurant Edge Attribute shape:", user_restaurant_edge_attr.shape)

This is the combined_data even after the split:
data: HeteroData(
user={ x=[583, 53] },
restaurant={ x=[142, 53] },
category={ x=[122, 53] },
(user, to, restaurant)={
edge_index=[2, 3669],
edge_attr=[3669, 50],
(user, to, category)={
edge_index=[2, 3669],
edge_attr=[3669, 50],
(restaurant, rev_to, user)={
edge_index=[2, 3669],
edge_attr=[3669, 50]
(category, rev_to, user)={
edge_index=[2, 3669],
edge_attr=[3669, 50]

My issues is why after using the RandomNodeSplit, it didn’t split the data to train, val and test_masks… So is there anything i’m doing wrong?