I have a fairly standard workflow with sequence data of shape (N, C, T), where each example is a single tensor and N can vary between examples. I have written a keyed batching pipeline that gathers up examples so that every forward pass through the model runs with a fixed batch size (B, C, T). Many examples might fit into a single batch, or one example might span multiple batches.
The problem I'm having is that my unbatchify step adds a lot of latency to the pipeline, so there is significant dead time on the GPU before the next batch is run.

Any recommendations for gathering batches without blocking?
```python
import torch

from random import randint
from itertools import groupby
from operator import itemgetter

BATCH, CHANNELS, TIME = 64, 512, 1000


def generate_data(examples=8):
    """ create example data of various batch sizes """
    for idx in range(1, examples + 1):
        yield ('key-%s' % idx, torch.rand((randint(2, 48), CHANNELS, TIME)))


def batchify(items, batchsize, dim=0):
    """ batch up multiple examples up to `batchsize` """
    stack, pos = [], 0
    for k, v in items:
        breaks = range(batchsize - pos, v.shape[dim], batchsize)
        for start, end in zip([0, *breaks], [*breaks, v.shape[dim]]):
            sub_batch = v[start:end]
            stack.append(((k, (pos, pos + end - start)), sub_batch))
            if pos + end - start == batchsize:
                ks, vs = zip(*stack)
                yield ks, torch.cat(vs, dim)
                stack, pos = [], 0
            else:
                pos += end - start
    if len(stack):
        ks, vs = zip(*stack)
        yield ks, torch.cat(vs, dim)


def unbatchify(batches, dim=0):
    """ reconstruct batches into the original examples """
    batches = (
        (k, v[start:end])
        for sub_batches, v in batches
        for k, (start, end) in sub_batches
    )
    return (
        (k, torch.cat([v for (k, v) in group], dim))
        for k, group in groupby(batches, itemgetter(0))
    )


def model(data):
    """ dummy model """
    return data + 1


batches = batchify(generate_data(), batchsize=BATCH)
results = ((key, model(data)) for key, data in batches)

for key, res in unbatchify(results):
    print(key, res.shape)
```
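For context, one direction I've been considering (but haven't properly tried yet) is handing finished batches to a background thread so that unbatchify overlaps with the next forward pass instead of blocking it. A minimal sketch using `threading` and `queue` from the standard library; `run_pipeline`, the sentinel, and the queue size are my own illustrative choices, not part of the pipeline above:

```python
import threading
import queue


def run_pipeline(batches, model, unbatch, maxsize=4):
    """Run the model on the calling thread while a background thread drains
    finished batches through `unbatch`, so reassembly overlaps compute."""
    results = queue.Queue(maxsize=maxsize)
    done = object()  # sentinel marking the end of the stream
    out = []

    def drain():
        # Yield results as they arrive until the sentinel appears.
        while (item := results.get()) is not done:
            yield item

    consumer = threading.Thread(target=lambda: out.extend(unbatch(drain())))
    consumer.start()
    for key, data in batches:
        # put() returns as soon as there is queue space, so the next
        # forward pass is not blocked on reassembly of earlier results.
        results.put((key, model(data)))
    results.put(done)
    consumer.join()
    return out
```

Would something along these lines play nicely with CUDA's async execution, or is there a more idiomatic way to overlap the two stages?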