Gathering batches efficiently

I have a fairly standard workflow with sequence data, where a single example has size (N, C, T) and N can vary. I have written a keyed batching pipeline that gathers up examples so that every forward pass of the model runs with a fixed batch size (B, C, T). Many examples might fit into a single batch, or one example might span multiple batches.

The problem I’m having is that my unbatchify step adds a lot of latency to the pipeline, so there is significant dead time on the GPU before the next batch is run.

Any recommendations for gathering batches without blocking?

import torch
from random import randint
from itertools import groupby
from operator import itemgetter

BATCH, CHANNELS, TIME = 64, 512, 1000

def generate_data(examples=8):
    """Yield `examples` synthetic ``(key, tensor)`` pairs for the demo.

    Keys run 'key-1', 'key-2', ... up to 'key-<examples>'. Each tensor comes
    from ``torch.rand`` with shape ``(n, CHANNELS, TIME)``, where the leading
    size ``n`` is drawn with ``randint(2, 48)`` (both ends inclusive). The
    examples therefore differ in length along dim 0.
    """
    for offset in range(examples):
        key = 'key-%s' % (offset + 1)
        length = randint(2, 48)
        yield key, torch.rand((length, CHANNELS, TIME))

def batchify(items, batchsize, dim=0):
    """Pack keyed examples into fixed-size batches along ``dim``.

    Examples are concatenated in arrival order. An example that does not fit
    into the space left in the current batch is split across batches, so one
    example may span several batches and one batch may hold several examples.

    Args:
        items: iterable of ``(key, tensor)`` pairs.
        batchsize: number of entries along ``dim`` in every yielded batch
            (except possibly the last one).
        dim: dimension to batch along (default 0).

    Yields:
        ``(keys, batch)`` pairs. ``batch`` is the ``torch.cat`` of the pieces
        along ``dim``. ``keys`` is a tuple of ``(key, (start, end))`` entries
        recording where each piece sits inside ``batch``. A final, partially
        filled batch is yielded if any pieces remain.
    """
    stack, pos = [], 0
    for k, v in items:
        length = v.shape[dim]
        # Split points inside v: first where the current batch fills up,
        # then every `batchsize` entries after that.
        breaks = range(batchsize - pos, length, batchsize)
        for start, end in zip([0, *breaks], [*breaks, length]):
            size = end - start
            # narrow() honours `dim`; plain v[start:end] always cut dim 0.
            stack.append(((k, (pos, pos + size)), v.narrow(dim, start, size)))
            if pos + size == batchsize:
                ks, vs = zip(*stack)
                yield ks, torch.cat(vs, dim)
                stack, pos = [], 0
            else:
                # Only advance the fill position while the batch is not full.
                pos += size
    if stack:
        ks, vs = zip(*stack)
        yield ks, torch.cat(vs, dim)

def unbatchify(batches, dim=0):
    """Reassemble the original keyed examples from batched outputs.

    This is the inverse of ``batchify``. Each ``(key, (start, end))`` entry
    selects a slice of its batch along ``dim``. Consecutive slices that share
    a key are concatenated back into a single tensor.

    Args:
        batches: iterable of ``(keys, batch)`` pairs shaped like ``batchify``
            output. ``batch`` may be a transformed tensor (e.g. model output),
            provided the ``(start, end)`` ranges still index it along ``dim``.
        dim: dimension the batches were built along (default 0).

    Returns:
        A lazy generator of ``(key, tensor)`` pairs in arrival order.

    Notes:
        ``itertools.groupby`` only merges *adjacent* pieces, so this relies on
        each example's pieces arriving contiguously, as ``batchify`` emits
        them. ``groupby`` must also see the next piece's key before it can
        close a group. An example is therefore emitted only after the
        following piece has been pulled, which may mean pulling (and
        computing) the next batch.
    """
    pieces = (
        (key, batch.narrow(dim, start, end - start))
        for keys, batch in batches
        for key, (start, end) in keys
    )
    return (
        (key, torch.cat([piece for _, piece in group], dim))
        for key, group in groupby(pieces, itemgetter(0))
    )

def model(data):
    """Stand-in for a real network: return the input shifted up by one."""
    shifted = data + 1
    return shifted

# Demo pipeline. Every stage is a lazy generator driven by the single `for`
# loop below, so batching, the model call and reassembly run one after
# another in this thread; no stage overlaps with another.
batches = batchify(generate_data(), batchsize=BATCH)  # BATCH is 64
results = (
    (key, model(data))
    for key, data in batches
)

for key, res in unbatchify(results):
    print(key, res.shape)