How to mask edge_index with train_mask

How can I mask edge_index so that it only contains edges between training nodes before passing it to a GCN model? I want to prevent data leakage between test nodes and training nodes.

import torch
import torch_geometric
from torch_geometric.nn import GCNConv

# Assuming you have a graph with node features and training labels
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
y = torch.tensor([1,1,1], dtype=torch.long)

train_mask = torch.tensor([True, True, False])  # Boolean mask indicating the training nodes

data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)
# Apply mask to input data
x_train = data.x[train_mask]
edge_index_train = ???  # <- what goes here?

# Create the GCNConv layer
conv = GCNConv(in_channels=1, out_channels=16)  # 1 input feature per node; 16 is an arbitrary example width
# Perform forward pass only on training nodes
out = conv(x_train, edge_index_train)

# Perform further computations with the output

For example, what should replace "???" in the code above?
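A minimal sketch in plain PyTorch, assuming the goal is an induced subgraph on the training nodes: keep only edges whose source and target are both training nodes, then relabel node ids so they index into `x[train_mask]` (row 0 of `x_train` is the first training node, and so on). The helper name `mask_edge_index` is hypothetical, not a PyTorch or PyG function.

```python
import torch

def mask_edge_index(edge_index: torch.Tensor, node_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return the edges among nodes selected by node_mask,
    relabeled to 0..node_mask.sum()-1 so they line up with x[node_mask]."""
    node_mask = node_mask.to(torch.bool)
    src, dst = edge_index
    # An edge survives only if BOTH endpoints are training nodes
    edge_mask = node_mask[src] & node_mask[dst]
    kept = edge_index[:, edge_mask]
    # Map old node ids to new contiguous ids (-1 marks dropped nodes)
    new_id = torch.full((node_mask.numel(),), -1, dtype=torch.long)
    new_id[node_mask] = torch.arange(int(node_mask.sum()), dtype=torch.long)
    # Indexing with a (2, E) tensor returns a (2, E) tensor of new ids
    return new_id[kept]

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
train_mask = torch.tensor([True, True, False])

x_train = x[train_mask]
edge_index_train = mask_edge_index(edge_index, train_mask)
print(edge_index_train)  # expected: [[0, 1], [1, 0]]
```

PyG also ships `torch_geometric.utils.subgraph(subset, edge_index, relabel_nodes=True)`, which returns an `(edge_index, edge_attr)` tuple and should give the same relabeled edges; check its documentation for your installed version. Note that this corresponds to an inductive setup where training never sees test nodes at all.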