Long time spent in backward() when inputs are sparse tensors


I’ll start by saying that we have a rather atypical use case, so I’ll present some background. We’re using pytorch to develop a GPU accelerated trainer for neural networks used by the chess engine Stockfish. Our networks have a large fully connected input layer (with sparse features), but the network is shallow and the later layers are small (we require the unbatched inference speed to be >1M/s/core on a CPU). Because of this we’re forced to do several things differently:

  • We are using sparse tensors as inputs. The density is about 0.1%. Moreover, we treat the batch dimension as sparse, so sparse_dim=2 and the shape is (batch_size, 41024).
  • We’re processing hundreds of thousands of predictions/examples per second.
  • We cannot use the DataLoader, because we need a fast data loader implemented in C++, and the concept of the DataLoader is fundamentally incompatible with the way we pass tensors and form batches.
  • Our batches are not divisible.
  • We cannot use DataParallel nor DistributedDataParallel for the above reasons (well, we got it kind of working with DataParallel after a lot of struggle, but it breaks the NVIDIA driver, hangs the machine, and requires a reboot).
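For context, here is a minimal sketch of the input format described above. The shape and sparse_dim are the real ones; the indices and values here are made up for illustration:

```python
import torch

# Hypothetical illustration of our input format: a COO sparse tensor
# with sparse_dim=2, i.e. both the batch dimension and the feature
# dimension are sparse. Real batches have ~0.1% density over 41024
# features; the indices below are invented for the example.
batch_size, num_features = 4, 41024
indices = torch.tensor([
    [0, 0, 1, 2, 3],          # batch indices
    [17, 1024, 5, 40000, 9],  # active feature indices
])
values = torch.ones(indices.shape[1])
batch = torch.sparse_coo_tensor(indices, values, (batch_size, num_features))
print(batch.sparse_dim())  # 2
```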

Therefore we decided to implement multi-GPU support in the simplest way possible. This turned out to be fairly straightforward. We manually replicate the model across devices, run forward on a single batch on each of them, compute the loss on each of them, run backward on each of them, accumulate the gradients on the main device, and perform an optimization step. Works great. However, when porting the solution from a playground to our actual trainer, we noticed a problem: performance didn’t scale across multiple GPUs. After some digging we identified that the issue stems from nothing else but the sparse tensors (they have been causing issues from the start; I hope this also answers the “what do you need sparse tensors for?” question that I see in every stagnated issue about them).

The problem is that with sparse inputs loss.backward() takes a huge amount of time, while our single-process multi-GPU training relies on the asynchronicity of forward and backward calls (which works perfectly with dense inputs).
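To illustrate what we mean by asynchronicity: CUDA kernel launches return immediately, so the wall-clock time of a forward or backward call only measures launch overhead unless we explicitly synchronize. A minimal sketch of the timing pattern we rely on (guarded so it only runs when a GPU is present):

```python
import time
import torch

def timed_matmul(device='cuda'):
    # Kernel launches are asynchronous: the matmul below returns almost
    # immediately, and the real work is only accounted for once we
    # synchronize with the device.
    x = torch.rand(4096, 4096, device=device)
    t0 = time.perf_counter()
    y = x @ x                        # queued, returns immediately
    launch = time.perf_counter() - t0
    torch.cuda.synchronize(device)   # wait for the kernel to finish
    total = time.perf_counter() - t0
    return launch, total

if torch.cuda.is_available():
    launch, total = timed_matmul()
    print('launch {:.4f}s, total {:.4f}s'.format(launch, total))
```

With dense inputs the launch time is near zero and all the time shows up after the synchronize; with sparse inputs that is no longer the case, as the script below demonstrates.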

We created a self-contained script presenting the issue (the batch size may need adjustment to reproduce the problem on different machines).

import torch
from torch import nn
import copy
import time

def test(batch_size, devices, sparse, warmup=False):
    # Seed the rng to have deterministic tests
    torch.manual_seed(0)

    print('Devices: ' + str(devices), 'Sparse: ' + str(sparse), 'Warmup: ' + str(warmup))

    # For some reason MSE loss requires a very low lr, otherwise it blows up
    learning_rate = 0.001

    # Whatever arch
    model = nn.Sequential(
        nn.Linear(512, 512),
        nn.Linear(512, 512),
        nn.Linear(512, 1)
    ).to(device=devices[0])

    # Whatever loss
    loss_fn = nn.MSELoss()

    # Whatever optimizer
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)

    # 0. We have 1 model, N devices, N batches, N outcome tensors
    def step(model, batches, outcomes, devices):
        # 1. Replicate the model to all devices
        local_models = [model] + [copy.deepcopy(model).to(device=device) for device in devices[1:]]

        # 2. Make each model do forward on 1 batch -> N x forward
        t0 = time.perf_counter()
        outs = [m(batch.to(device=device, non_blocking=True)) for batch, m, device in zip(batches, local_models, devices)]
        t1 = time.perf_counter()
        for device in devices:
            torch.cuda.synchronize(device)
        t2 = time.perf_counter()
        if not warmup:
            print('forward  {:6.3f} seconds'.format(t1-t0))
            print('sync     {:6.3f} seconds'.format(t2-t1))

        # 3. Compute loss for each separate forward -> N losses
        losses = [loss_fn(out, outcome.to(device=device, non_blocking=True)) for outcome, out, device in zip(outcomes, outs, devices)]

        # 4. Remove gradients from all parameters. This has to be done before backwards.
        #    This should be better than zero_grad because it doesn't block and makes
        #    the first backward pass assign instead of add - less memory usage
        for m in local_models:
            for param in m.parameters():
                param.grad = None

        # 5. Do backward for each loss separately. This *should* not block
        t0 = time.perf_counter()
        for loss in losses:
            loss.backward()
        t1 = time.perf_counter()
        for device in devices:
            torch.cuda.synchronize(device)
        t2 = time.perf_counter()
        if not warmup:
            print('backward {:6.3f} seconds'.format(t1-t0))
            print('sync     {:6.3f} seconds'.format(t2-t1))

        # 6. Non blocking transfer of all gradients to the main device
        #    This shouldn't be that much data for our small net
        grads_by_model = [[param.grad.to(device=devices[0], non_blocking=True) for param in m.parameters()] for m in local_models[1:]]

        # 7. Accumulate gradients. We don't want to average them because we're not
        #    splitting the batch, we're taking multiple batches in one step.
        for grads in grads_by_model:
            for main_param, grad in zip(model.parameters(), grads):
                main_param.grad += grad

        # 8. Optimizer runs with the accumulated gradients on the main model only.
        optimizer.step()

        # Return loss for diagnostic
        return sum(loss.item() for loss in losses) / len(losses)

    # Random batches and outcomes. We don't care whether they are different for each iteration
    # so we do it once because it's faster.
    # Note that we're scaling the batch size by the number of devices so that
    # it's transparent to the user.
    batches = [(torch.rand(batch_size // len(devices), 512) * 100.0 - 99.0).clamp(0.0, 1.0).to(device=device, non_blocking=True) for device in devices]
    if sparse:
        batches = [b.to_sparse() for b in batches]

    outcomes = [torch.rand(batch_size // len(devices), 1).to(device=device, non_blocking=True) for device in devices]

    start_time = time.perf_counter()

    losses = []
    # We do a fixed number of batch_size chunks, as the user expects
    for i in range(10):
        losses.append(step(model, batches, outcomes, devices))

    # Ensure everything completed before measuring time
    for device in devices:
        torch.cuda.synchronize(device)

    end_time = time.perf_counter()
    if not warmup:
        print('{:6.3f} seconds'.format(end_time-start_time))
        print('Loss went from {} to {}'.format(losses[0], losses[-1]))

batch_size = 2**15
# warmup
test(batch_size, ['cuda:0'], sparse=False, warmup=True)
test(batch_size, ['cuda:0'], sparse=False)

# warmup
test(batch_size, ['cuda:0'], sparse=True, warmup=True)
test(batch_size, ['cuda:0'], sparse=True)

The behaviour we observe is that for dense inputs, forward and backward execute asynchronously and take almost no time; nearly all the time is spent in sync, which is what we want because it allows scheduling multiple forward/backward passes in parallel.
But with sparse=True we observe that backward takes much longer, usually longer than the subsequent sync. This completely defeats the gains from our multi-GPU setup. (In our real case we also observe similarly high times for forward, though that is not visible in this example.)
(Results from my GTX750: https://pastebin.com/TfGiuWKT)

Our questions are:

  1. Why does backward take so much time when the inputs are sparse tensors?
  2. Is this a bug in pytorch?
  3. How can we work around this? Preferably without spawning multiple processes.

Is it perhaps this https://github.com/pytorch/pytorch/blob/cfe3defd88b43ba710dd1093e382c5e8c279bd83/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu#L147 that we’re seeing take up time in forward in our original case? That would partially explain why we only see backward taking longer in the toy example (because to_sparse returns a coalesced tensor?). But it raises a question: if this is indeed the issue, why isn’t the tensor coalesced during backward?
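For what it’s worth, to_sparse() does appear to hand back an already-coalesced tensor, which is why we suspect the implicit coalesce linked above only shows up on the backward path in the toy example. A quick check (on CPU; we assume the same holds on CUDA):

```python
import torch

# Check whether Tensor.to_sparse() returns a coalesced tensor.
# The clamp pattern mirrors how the toy script builds its sparse batches.
dense = (torch.rand(8, 16) * 100.0 - 99.0).clamp(0.0, 1.0)
sparse = dense.to_sparse()
print(sparse.is_coalesced())  # True in our tests
```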