A strange speed-up when using the Quiver sampler together with loss.backward()

I have already raised a similar issue on the Quiver GitHub repository (Quiver is a distributed graph learning library for PyTorch Geometric). The issue link and basic background can be found here: A strange phenomenon about quiver sampler · Issue #122 · quiver-team/torch-quiver · GitHub

(Please read the issue linked above before reading the rest of this post.)

However, the Quiver authors do not seem to know the cause, so I am turning to the PyTorch community for help. The strange speed-up is related not only to the Quiver sampler but also to the autograd process. Below is a table showing how the training time changes when parts of the model code are commented out.


We can see that after adding loss.backward(), the measured time decreased sharply.

Any help would be appreciated, thanks!

Could you post a minimal, executable code snippet showing this behavior, please?

OK, here is a code example. You will need to install torch-geometric and torch-quiver, and a CUDA GPU is required because the sampler runs in UVA mode on device 0. The Reddit dataset should then be downloaded automatically.

import os.path as osp
import time

import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv

import quiver
# The Reddit dataset is stored under <script dir>/../data/Reddit and is
# downloaded there by PyG if it is not already present.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(root=path)
data = dataset[0]

# Indices of the training nodes, served as shuffled mini-batches of 128 seed
# nodes; the last (possibly smaller) batch is kept.
train_idx = torch.nonzero(data.train_mask, as_tuple=False).view(-1)
train_loader = torch.utils.data.DataLoader(
    train_idx,
    batch_size=128,
    shuffle=True,
    drop_last=False,
)  # Quiver

# Quiver: CSR topology of the graph and a UVA-mode GraphSAGE neighbour sampler
# on CUDA device 0 with per-hop fan-outs [25, 10].
csr_topo = quiver.CSRTopo(data.edge_index)  # Quiver
quiver_sampler = quiver.pyg.GraphSageSampler(
    csr_topo, sizes=[25, 10], device=0, mode='UVA'
)  # Quiver

class SAGE(torch.nn.Module):
    """GraphSAGE model for mini-batch training on sampled neighbourhoods.

    Stacks ``num_layers`` ``SAGEConv`` layers with channel widths
    ``in_channels -> hidden_channels -> ... -> hidden_channels -> out_channels``.
    ReLU and dropout (p=0.5) follow every layer except the last, and the final
    output is passed through ``log_softmax``.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every hidden layer.
        out_channels: Size of the output (number of classes).
        num_layers: Number of ``SAGEConv`` layers (default 2). It should match
            the number of per-layer adjacencies the sampler produces.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.num_layers = num_layers

        # Layer i maps widths[i] -> widths[i + 1]. With num_layers == 2 this is
        # SAGEConv(in, hidden) followed by SAGEConv(hidden, out).
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(src, dst) for src, dst in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x, adjs):
        """Apply the stacked convolutions to one sampled mini-batch.

        Args:
            x: Features of all sampled nodes. The target nodes of each layer
                must come first, because ``x[:size[1]]`` is used as the target
                set.
            adjs: Per-layer ``(edge_index, _, size)`` triples, consumed in order
                by ``self.convs[0]``, ``self.convs[1]``, and so on.
                ``size[1]`` is the number of target nodes for that layer.

        Returns:
            ``log_softmax`` over the last dimension of the final layer's output
            (presumably one row per target node of the last layer; TODO
            confirm against the sampler's output).
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)

# Train on the GPU when one is present. Note that the Quiver sampler is
# configured separately with device=0.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 256 hidden channels. Module.to() returns the module itself, so construction
# and placement can be chained before handing the parameters to Adam.
model = SAGE(dataset.num_features, 256, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Full feature matrix and labels are moved to `device` once; mini-batches are
# later gathered from them by indexing.
x = data.x.to(device)  # Original PyG Code
y = data.y.squeeze().to(device)

def _sync_device():
    """Wait for queued CUDA work to finish; no-op when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _sample_batch(seeds):
    """Sample a mini-batch around ``seeds`` and move its adjacencies to ``device``.

    Returns the ``(n_id, batch_size, adjs)`` triple from ``quiver_sampler.sample``,
    with every element of ``adjs`` moved to ``device`` via ``.to``.
    """
    n_id, batch_size, adjs = quiver_sampler.sample(seeds)  # Quiver
    adjs = [adj.to(device) for adj in adjs]
    return n_id, batch_size, adjs


def train(epoch: int, mode: int = 0) -> float:
    """Run one pass over ``train_loader`` and return the measured wall time.

    Args:
        epoch: Epoch index. It is only used in the printed report.
        mode: ``0`` times the full training step: sampling, moving the sampled
            adjacencies to ``device``, forward, ``loss.backward()`` and the
            optimizer update. ``1`` times only sampling plus the adjacency move.

    Returns:
        Seconds spent inside the timed region, summed over all mini-batches.

    Raises:
        ValueError: if ``mode`` is not 0 or 1. Previously such a mode silently
            returned ``None``.
    """
    if mode not in (0, 1):
        raise ValueError(f"mode must be 0 or 1, got {mode!r}")

    model.train()
    elapsed = 0.0
    for seeds in train_loader:
        # CUDA kernels run asynchronously, so synchronize on both sides of the
        # timed region to attribute this step's GPU work to this step only.
        _sync_device()
        start = time.perf_counter()  # monotonic clock, suited to intervals
        n_id, batch_size, adjs = _sample_batch(seeds)
        if mode == 0:
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        _sync_device()
        elapsed += time.perf_counter() - start

    label = "total" if mode == 0 else "sample"
    print("Epoch %s elapsed %s time: %s" % (epoch, label, elapsed))
    return elapsed

# Benchmark: time NUM_EPOCHS full training epochs and average over all but the
# first WARMUP_EPOCHS. Those are excluded as warm-up, presumably to drop
# one-off initialization costs.
NUM_EPOCHS = 20
WARMUP_EPOCHS = 2
HIDDEN_SIZE = 256  # reported only; keep in sync with the SAGE(..., 256, ...) call

if NUM_EPOCHS <= WARMUP_EPOCHS:
    raise ValueError("NUM_EPOCHS must be greater than WARMUP_EPOCHS")

total_time = []
for epoch in range(NUM_EPOCHS):
    e_time = train(epoch, mode=0)
    total_time.append(e_time)

measured = total_time[WARMUP_EPOCHS:]
print("Average time: %s, num_workers: %s, batch_size: %s, hidden_size: %s"
      % (sum(measured) / len(measured), train_loader.num_workers,
         train_loader.batch_size, HIDDEN_SIZE))