Efficient way to convert multiple PIL images to 5D batch tensor

I’m training a network that takes three RGB images as input. I have them stored as bytes in an LMDB file, and I’m using tensorpack’s DataFlow to read them.

For each batch, I need to convert my images and ground truth to PyTorch tensors, with the images as a single [N, 3, 3, 64, 64] tensor (batch, images per sample, channels, height, width).
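To make the target layout concrete, here is a small shape-only sketch: each sample is three [3, 64, 64] images stacked into [3, 3, 64, 64], and N samples are stacked into [N, 3, 3, 64, 64]. The batch size `N = 4` and the zero-filled images are placeholders, not values from the question.

```python
import torch

N = 4  # hypothetical batch size

# One sample: three RGB images, each [C=3, H=64, W=64]
sample = [torch.zeros(3, 64, 64) for _ in range(3)]
per_sample = torch.stack(sample)       # [3, 3, 64, 64]
batch = torch.stack([per_sample] * N)  # [N, 3, 3, 64, 64]
print(batch.shape)  # torch.Size([4, 3, 3, 64, 64])
```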

Currently, I’m doing it the “easy” way: a nested loop converts each byte array to a PIL Image and then to a tensor, and each set of three images is stacked, after which the sets are stacked into the final 5D batch tensor.

crops_tensor = torch.stack([torch.stack([transform_fn(Image.open(io.BytesIO(crop))) for crop in crops]) for crops in crops_batch])

where crops_batch is yielded from the DataFlow at each iteration.
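A self-contained version of the one-liner above, useful as a baseline for timing, might look like the sketch below. It assumes `transform_fn` behaves like torchvision’s `ToTensor` (HWC uint8 to CHW float in [0, 1]); the question does not show the real transform, so the stand-in here is hypothetical, as is the `make_png_bytes` helper that fakes LMDB records with random 64x64 PNGs.

```python
import io

import numpy as np
import torch
from PIL import Image


def transform_fn(img):
    # Hypothetical stand-in for the question's transform_fn:
    # HWC uint8 -> CHW float32 in [0, 1], similar to torchvision's ToTensor.
    arr = np.asarray(img, dtype=np.uint8)
    return torch.from_numpy(arr.copy()).permute(2, 0, 1).float().div(255)


def make_png_bytes(seed):
    # Hypothetical helper: encode a random 64x64 RGB image as PNG bytes,
    # standing in for one crop read from the LMDB.
    rng = np.random.default_rng(seed)
    pixels = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
    buf = io.BytesIO()
    Image.fromarray(pixels).save(buf, format="PNG")
    return buf.getvalue()


# crops_batch: N = 4 samples, each a list of three encoded crops
crops_batch = [[make_png_bytes(3 * i + j) for j in range(3)] for i in range(4)]

# The nested-loop conversion from the question
crops_tensor = torch.stack([
    torch.stack([transform_fn(Image.open(io.BytesIO(crop))) for crop in crops])
    for crops in crops_batch
])
print(crops_tensor.shape)  # torch.Size([4, 3, 3, 64, 64])
```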

As you can imagine, this is a huge bottleneck for my script (as seen in the snakeviz profile below).

What is a more efficient way to go from N sets of three encoded image byte strings to the 5D tensor used in my training routine?
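One possible direction, sketched under stated assumptions: decode every crop into a single preallocated uint8 NumPy array, then do the tensor creation, HWC to CHW permute, and float conversion once per batch instead of once per image. This removes the per-image tensor allocations and the two levels of `torch.stack`, but each PNG/JPEG still has to be decoded. If decoding dominates the profile, storing the crops in the LMDB as raw uint8 pixel buffers (an assumption; the question does not say how they were written) lets `np.frombuffer` skip decoding entirely, which `batch_from_raw` illustrates. Both function names are hypothetical, and both assume 64x64 RGB crops.

```python
import io

import numpy as np
import torch
from PIL import Image


def batch_from_encoded(crops_batch, size=64):
    """Decode N x 3 encoded crops (PNG/JPEG bytes) into a [N, 3, 3, H, W] float tensor."""
    n = len(crops_batch)
    buf = np.empty((n, 3, size, size, 3), dtype=np.uint8)  # [N, views, H, W, C]
    for i, crops in enumerate(crops_batch):
        for j, crop in enumerate(crops):
            with Image.open(io.BytesIO(crop)) as img:
                # Assumes every crop is size x size; a mismatch raises here
                buf[i, j] = np.asarray(img.convert("RGB"))
    # One conversion for the whole batch: [N, 3, H, W, C] -> [N, 3, C, H, W], scaled to [0, 1]
    return torch.from_numpy(buf).permute(0, 1, 4, 2, 3).contiguous().float().div_(255)


def batch_from_raw(crops_batch, size=64):
    """Assumes each crop was stored as raw HWC uint8 pixels (size * size * 3 bytes)."""
    flat = b"".join(crop for crops in crops_batch for crop in crops)
    arr = np.frombuffer(flat, dtype=np.uint8).reshape(len(crops_batch), 3, size, size, 3)
    # copy() because frombuffer returns a read-only view of the bytes object
    return torch.from_numpy(arr.copy()).permute(0, 1, 4, 2, 3).contiguous().float().div_(255)


def to_png(arr):
    # Hypothetical helper: encode one HWC uint8 array as PNG bytes
    out = io.BytesIO()
    Image.fromarray(arr).save(out, format="PNG")
    return out.getvalue()


# Synthetic batch: N = 4 samples x 3 crops of 64x64 RGB
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(4, 3, 64, 64, 3), dtype=np.uint8)
encoded_batch = [[to_png(pixels[i, j]) for j in range(3)] for i in range(4)]
raw_batch = [[pixels[i, j].tobytes() for j in range(3)] for i in range(4)]

fast = batch_from_encoded(encoded_batch)
raw = batch_from_raw(raw_batch)
print(fast.shape)  # torch.Size([4, 3, 3, 64, 64])
```

Because PNG is lossless, both paths should produce identical tensors on this synthetic data. How much either helps in practice depends on where the snakeviz profile shows the time going; moving the conversion into parallel DataFlow workers is another option worth measuring.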