Speed up dataloader when complex operations are needed in the __getitem__ function


I am trying to write a custom data loader for graph-structured data. In the `__getitem__` function, I permute the edges of the graph selected by the `index` parameter and then sample different groups of edges from it (in the flavor of meta-learning). However, these complex operations make the data loader very slow; during training, the Volatile GPU-Util is always 0. Here is the sample code for the custom data loader.

class DataGenerator(Dataset):
    """Dataset that yields one edge-sampling task per graph.

    Each item corresponds to one graph. Its edges are randomly permuted and
    split into a "sample" group and a "query" group of positive edges, and
    a matching number of negative (non-edge) node pairs is drawn for each
    group.

    Args:
        args: Configuration object. The attributes read here are ``seed``,
            ``use_cross_graph_meta``, ``test_graph_number``,
            ``sample_ratio``, ``query_ratio`` and ``neg_pos_ratio``.
        graph_group: Sequence of graph locations, each passed to
            ``nx.read_gpickle``.
        status: ``'train'`` to use every graph except the last
            ``args.test_graph_number`` ones, ``'test'`` to use only those.

    Raises:
        ValueError: If ``status`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, args, graph_group, status):
        super().__init__()
        self.args = args
        self.graphs = graph_group
        self.graph_nx = []
        self.row = []
        self.col = []
        self.status = status

        # Independent generators, all seeded identically, so every sampling
        # step is reproducible on its own.
        self.rng_permutation_train = np.random.RandomState(args.seed)
        self.rng_choose_edge_train = np.random.RandomState(args.seed)
        self.rng_permutation_test = np.random.RandomState(args.seed)
        self.rng_choose_edge_test = np.random.RandomState(args.seed)
        if args.use_cross_graph_meta:
            # NOTE(review): the seed argument of the first generator was
            # truncated in the original snippet; args.seed is assumed, as
            # for every other generator here.
            self.rng_choose_graph_train_cross = np.random.RandomState(
                args.seed)
            self.rng_permutation_train_cross = np.random.RandomState(
                args.seed)
            self.rng_choose_edge_train_cross = np.random.RandomState(
                args.seed)

        if status == 'train':
            self.graphs_in_use = self.graphs[:-args.test_graph_number]
        elif status == 'test':
            self.graphs_in_use = self.graphs[-args.test_graph_number:]
        else:
            raise ValueError(
                "status must be 'train' or 'test', got {!r}".format(status))

        for graph in self.graphs_in_use:
            # SECURITY: read_gpickle unpickles the file; only load graphs
            # from trusted sources.
            graph_temp = nx.read_gpickle(graph)
            edges = list(graph_temp.edges())
            # NOTE(review): the statements wrapping the two edge lists were
            # lost in the original snippet. Storing the graph plus the
            # endpoints as long tensors is assumed from how self.graph_nx,
            # self.row and self.col are used elsewhere -- confirm.
            self.graph_nx.append(graph_temp)
            self.row.append(
                torch.tensor([e[0] for e in edges], dtype=torch.long))
            self.col.append(
                torch.tensor([e[1] for e in edges], dtype=torch.long))

    def __getitem__(self, index):
        """Return the sampled edge groups for the graph at ``index``.

        Returns:
            Tuple ``(pos_train_edges, neg_train_edges, pos_test_edges,
            neg_test_edges)`` of ``(2, k)`` edge-index tensors.
        """
        # NOTE(review): the third argument was truncated in the original
        # snippet. The graph's node count is assumed, which requires node
        # labels to be integers in [0, number_of_nodes) -- confirm.
        num_nodes = self.graph_nx[index].number_of_nodes()
        pos_train_edges, neg_train_edges, pos_test_edges, neg_test_edges \
            = self.sampling_function_for_task(index, self.args, num_nodes)
        return pos_train_edges, neg_train_edges, \
            pos_test_edges, neg_test_edges

    def __len__(self):
        """Return the number of graphs available for this status."""
        return len(self.graphs_in_use)

    def sampling_function_for_task(self, index, args, num_nodes):
        """Split one graph's edges into sample/query positives and negatives.

        Args:
            index: Position of the graph in ``self.row`` / ``self.col``.
            args: Configuration providing ``sample_ratio``, ``query_ratio``
                and ``neg_pos_ratio``.
            num_nodes: Side length of the node-pair mask; every node id in
                the graph must be smaller than this value.

        Returns:
            Tuple ``(sample_pos_edge_index, sample_neg_edge_index,
            query_pos_edge_index, query_neg_edge_index)``; each entry is a
            tensor of shape ``(2, k)``.
        """
        row = self.row[index]
        col = self.col[index]
        num_edges = row.size(0)

        n_v = int(math.floor(args.sample_ratio * num_edges))
        n_t = int(math.floor(args.query_ratio * num_edges))
        neg_per_pos = int(args.neg_pos_ratio)

        # Positive edges: shuffle once, then take consecutive slices.
        perm = torch.as_tensor(
            self.rng_permutation_train.permutation(num_edges),
            dtype=torch.long)
        row, col = row[perm], col[perm]

        sample_pos_edge_index = torch.stack([row[:n_v], col[:n_v]], dim=0)
        query_pos_edge_index = torch.stack(
            [row[n_v:n_v + n_t], col[n_v:n_v + n_t]], dim=0)

        # Negative edges: start from every pair (i, j) with i < j and drop
        # the pairs that are real edges.
        # NOTE: this mask is dense, so building and scanning it costs
        # O(num_nodes ** 2) time and memory on every call.
        neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
        neg_adj_mask = neg_adj_mask.triu(diagonal=1)
        row = row.to(torch.long)
        col = col.to(torch.long)
        # Only the upper triangle holds candidates, so clear both
        # orientations; an edge stored as (u, v) with u > v must still
        # exclude the candidate pair (v, u).
        neg_adj_mask[row, col] = 0
        neg_adj_mask[col, row] = 0

        neg_row, neg_col = neg_adj_mask.nonzero().t()
        # NOTE(review): the trailing arguments of this call were truncated
        # in the original snippet; sampling without replacement is assumed
        # so the sample and query negatives stay disjoint.
        num_neg = neg_per_pos * (n_t + n_v)
        perm = torch.as_tensor(
            self.rng_choose_edge_train.choice(
                neg_row.size(0), num_neg, replace=False),
            dtype=torch.long)
        neg_row, neg_col = neg_row[perm], neg_col[perm]

        split = neg_per_pos * n_v
        sample_neg_edge_index = torch.stack(
            [neg_row[:split], neg_col[:split]], dim=0)
        query_neg_edge_index = torch.stack(
            [neg_row[split:num_neg], neg_col[split:num_neg]], dim=0)

        return sample_pos_edge_index, sample_neg_edge_index, \
            query_pos_edge_index, query_neg_edge_index

I have tried some speed-up techniques, such as adding a prefetcher, but it is still slow. Does anybody have suggestions on how to refactor this code block to make it more efficient?

For easier reading: here "graph_group" is the list of file paths of the input graphs. Each graph is composed of edges, so graph_temp.edges looks like [(1, 12), …, (12345, 34567)], which means, for example, that nodes 1 and 12 are connected and form an edge.