How to define train_mask, val_mask, test_mask, ... in my own dataset?

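Compiling the information in this thread: to add a mask to a custom dataset you have to A) define the mask as a boolean tensor with one entry per node, B) attach it to each Data object as an extra attribute, and C) collect the Data objects into a list (or a Dataset) and pass it to a DataLoader.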
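For step A, a common choice is a random three-way node split. The sketch below assumes a 60/20/20 split and a fixed seed. The `random_split_masks` helper and those fractions are illustrative, not part of PyG. Recent PyG versions also provide a `torch_geometric.transforms.RandomNodeSplit` transform if you'd rather not write this yourself.

```python
import torch

def random_split_masks(num_nodes: int, train_frac: float = 0.6,
                       val_frac: float = 0.2, seed: int = 0):
    """Hypothetical helper: disjoint boolean train/val/test masks of shape [num_nodes]."""
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_nodes, generator=generator)  # shuffled node indices
    num_train = int(train_frac * num_nodes)
    num_val = int(val_frac * num_nodes)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[perm[:num_train]] = True
    val_mask[perm[num_train:num_train + num_val]] = True
    test_mask[perm[num_train + num_val:]] = True  # every remaining node goes to test
    return train_mask, val_mask, test_mask

train_mask, val_mask, test_mask = random_split_masks(10)
print(int(train_mask.sum()), int(val_mask.sum()), int(test_mask.sum()))  # 6 2 2
```

To attach the result to a graph (step B), something like `data.train_mask, data.val_mask, data.test_mask = random_split_masks(data.num_nodes)` should work.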

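```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
training_data = []
data_point = Data(x=x, edge_index=edge_index, y=y)
data_point.train_mask = torch.tensor([...], dtype=torch.bool)  # [...] has one bool per node, i.e. length x.size(0)
training_data.append(data_point)  # repeat these three lines for every graph
loader = DataLoader(training_data, batch_size=32)
```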

Hope this helps everyone!
