I have a fairly standard workflow with sequence data (each example has shape (N, C, T), where N can vary in size). I have written a keyed batching pipeline that gathers up examples so that every forward pass through the model runs with a fixed batch size (B, C, T). Many examples might fit into a single batch, or one example might span multiple batches.
The problem I’m having is that my unbatchify step adds a lot of latency to the pipeline, so there is significant dead time on the GPU before the next batch is run.
Any recommendations for gathering batches without blocking?
import torch
from random import randint
from itertools import groupby
from operator import itemgetter
# Pipeline-wide dimensions: target batch size, channels, and timesteps per example.
# NOTE(review): BATCH is never referenced below — the driver hard-codes 64 at the
# batchify() call site; consider using the constant there to keep them in sync.
BATCH, CHANNELS, TIME = 64, 512, 1000
def generate_data(examples=8):
    """ Yield `examples` keyed tensors whose leading dim varies randomly (2..48). """
    for n in range(examples):
        key = 'key-%s' % (n + 1)
        sample = torch.rand((randint(2, 48), CHANNELS, TIME))
        yield key, sample
def batchify(items, batchsize, dim=0):
    """ Batch up multiple examples up to `batchsize` along `dim`.

    Args:
        items: iterable of ``(key, tensor)`` pairs.
        batchsize: fixed size of every yielded batch (the final one may be smaller).
        dim: dimension along which examples are split and concatenated.

    Yields:
        ``(keys, batch)`` pairs, where ``keys`` is a tuple of
        ``(key, (start, end))`` entries locating each example's slice
        inside ``batch`` along ``dim``.
    """
    stack, pos = [], 0
    for k, v in items:
        size = v.shape[dim]
        # Split points: the first chunk tops up the current partial batch,
        # subsequent chunks are full `batchsize` strides.
        breaks = range(batchsize - pos, size, batchsize)
        for start, end in zip([0, *breaks], [*breaks, size]):
            # narrow() slices along `dim`; the original `v[start:end]`
            # only ever sliced dim 0, breaking any dim != 0.
            sub_batch = v.narrow(dim, start, end - start)
            stack.append(((k, (pos, pos + end - start)), sub_batch))
            if pos + end - start == batchsize:
                ks, vs = zip(*stack)
                # Bug fix: original called torch.cat(vs) without `dim` here
                # (while the flush path below passed it), silently
                # concatenating on dim 0 for dim != 0.
                yield ks, torch.cat(vs, dim)
                stack, pos = [], 0
            else:
                pos += end - start
    if stack:
        # Flush the final, possibly partial batch.
        ks, vs = zip(*stack)
        yield ks, torch.cat(vs, dim)
def unbatchify(batches, dim=0):
    """ Reconstruct batches back into the original keyed examples.

    Consecutive sub-batches sharing a key are concatenated along `dim`
    and yielded as a single ``(key, tensor)`` pair.
    """
    pending_key, pieces = None, []
    for sub_batches, batch in batches:
        for key, (start, end) in sub_batches:
            # A key change means the previous example is complete.
            if pieces and key != pending_key:
                yield pending_key, torch.cat(pieces, dim)
                pieces = []
            pending_key = key
            pieces.append(batch[start:end])
    if pieces:
        yield pending_key, torch.cat(pieces, dim)
def model(data):
    """ Dummy model: shifts every element up by one. """
    shifted = data + 1
    return shifted
# Drive the pipeline end to end: batch, run the model, then reassemble
# per-key examples and report their shapes.
stream = batchify(generate_data(), batchsize=64)
scored = ((key, model(batch)) for key, batch in stream)
for key, example in unbatchify(scored):
    print(key, example.shape)