Fast PyTorch processing of TensorFlow TFRecord data

Hi all,

I want to devise an efficient way to load data from a set of (relatively large) TFRecord files and pass it to my PyTorch model for training and inference. We’re talking about ca. 64 GB of data in total right now, but the idea is to scale this to much larger datasets in the future. The TFRecords were generated with the tfds (TensorFlow Datasets) API. Each sample consists of 3 tensors packed into a dict: “X_hr” and “X_lr” (high-res and low-res inputs) and the target “Y” (this is a super-resolution problem). I can’t duplicate the data, i.e. reading the TFRecords in TensorFlow and then saving them in a more PyTorch-friendly format is not an option.
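As far as I know, tfds stores each sample as a serialized `tf.train.Example` protobuf with one feature per dict key (“X_hr”, “X_lr”, “Y”). To see what is actually inside a record without TensorFlow, here is a minimal sketch of a pure-Python decoder for the `tf.train.Example` wire format (bytes, float and int64 lists only). `parse_example` and its helpers are hypothetical names; the official TensorFlow or protobuf parsers are the safer choice in production.

```python
import struct
from typing import Dict, Iterator, List, Tuple, Union

Value = Union[List[bytes], List[float], List[int]]


def _varint(buf: bytes, pos: int) -> Tuple[int, int]:
    """Decode a base-128 varint starting at buf[pos]; return (value, new_pos)."""
    result = shift = 0
    while True:
        b = buf[pos]
        pos += 1
        result |= (b & 0x7F) << shift
        if not b & 0x80:
            return result, pos
        shift += 7


def _signed64(x: int) -> int:
    """Reinterpret an unsigned 64-bit varint value as a signed int64."""
    return x - (1 << 64) if x >= (1 << 63) else x


def _fields(buf: bytes) -> Iterator[tuple]:
    """Yield (field_number, wire_type, value) for every field of a message."""
    pos = 0
    while pos < len(buf):
        key, pos = _varint(buf, pos)
        field, wire = key >> 3, key & 7
        if wire == 0:                             # varint
            value, pos = _varint(buf, pos)
        elif wire == 1:                           # fixed 64-bit
            value, pos = buf[pos:pos + 8], pos + 8
        elif wire == 2:                           # length-delimited
            n, pos = _varint(buf, pos)
            value, pos = buf[pos:pos + n], pos + n
        elif wire == 5:                           # fixed 32-bit
            value, pos = buf[pos:pos + 4], pos + 4
        else:
            raise ValueError(f"unsupported wire type {wire}")
        yield field, wire, value


def _feature(buf: bytes) -> Value:
    """Decode a Feature: field 1 = BytesList, 2 = FloatList, 3 = Int64List."""
    for kind, _, inner in _fields(buf):
        if kind not in (1, 2, 3):
            continue                              # unknown field: ignore
        values: list = []
        for _, wire, v in _fields(inner):         # the single repeated `value` field
            if kind == 1:
                values.append(v)
            elif kind == 2:                       # packed (wire 2) or unpacked (wire 5) floats
                values.extend(struct.unpack(f"<{len(v) // 4}f", v))
            elif wire == 2:                       # packed int64 varints
                p = 0
                while p < len(v):
                    x, p = _varint(v, p)
                    values.append(_signed64(x))
            else:                                 # unpacked int64 varint
                values.append(_signed64(v))
        return values
    return []                                     # Feature with no kind set


def parse_example(serialized: bytes) -> Dict[str, Value]:
    """Decode a serialized tf.train.Example into {feature_name: list_of_values}."""
    out: Dict[str, Value] = {}
    for _, _, features in _fields(serialized):    # Example.features (field 1)
        for _, _, entry in _fields(features):     # map<string, Feature> entries (field 1)
            name, value = "", []
            for f, _, v in _fields(entry):
                if f == 1:
                    name = v.decode("utf-8")      # map key
                elif f == 2:
                    value = _feature(v)           # map value
            out[name] = value
    return out
```

For large float tensors, `np.frombuffer(v, dtype="<f4")` on the packed payload would be much faster than `struct.unpack` into a Python list; the pure-stdlib version is only meant to make the format visible.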

Here is what the data look like:

What would be the best way to go about this?

1/ Write a custom torch.utils.data.Dataset that wraps a tf.data.TFRecordDataset object? The TFRecordDataset would be in charge of decoding/decompressing the binary data and producing samples as needed.

2/ Write a custom DataLoader that accepts a TFRecordDataset object, similar to the Kaggle notebook “G2Net: Read from TFRecord & Train with PyTorch”? I’m not sure whether this would work well with multiprocessing.

3/ Hook up the TFRecordDataset directly to my PyTorch model (how?)

4/ Forget about PyTorch and do end-to-end TF for this task?! (not a real option! :))
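On the multiprocessing worry in options 1 and 2: if the wrapper is a `torch.utils.data.IterableDataset`, each DataLoader worker process gets its own copy of the dataset object, so without explicit sharding every worker would stream the same files and each sample would show up `num_workers` times. `torch.utils.data.get_worker_info()` returns `None` in the main process and an object with `id` and `num_workers` inside a worker, which is enough to split the file list. A minimal sketch under those assumptions; `ShardedRecordDataset`, `files_for_worker`, `read_records` and `decode` are hypothetical names, and the last two stand for whatever reads raw records from one file and turns one record into a dict of arrays.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence

try:
    from torch.utils.data import IterableDataset, get_worker_info
except ImportError:  # fallback only so this sketch can run without torch installed
    IterableDataset = object

    def get_worker_info() -> None:
        return None


def files_for_worker(files: Sequence[str], info: Optional[Any]) -> List[str]:
    """Round-robin split: worker k of n reads files k, k + n, k + 2n, ..."""
    if info is None:  # main process (num_workers=0): read every file
        return list(files)
    return list(files)[info.id::info.num_workers]


class ShardedRecordDataset(IterableDataset):
    """Streams decoded samples; each DataLoader worker reads a disjoint set of files."""

    def __init__(
        self,
        files: Sequence[str],
        read_records: Callable[[str], Iterable[bytes]],  # hypothetical: path -> raw records
        decode: Callable[[bytes], Dict[str, Any]],       # hypothetical: record -> sample dict
    ) -> None:
        self.files = list(files)
        self.read_records = read_records
        self.decode = decode

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        # get_worker_info() is evaluated inside the worker process, so each
        # worker computes its own shard of the file list here.
        for path in files_for_worker(self.files, get_worker_info()):
            for raw in self.read_records(path):
                yield self.decode(raw)
```

Wrapped in `DataLoader(ds, batch_size=..., num_workers=4)`, this gives each worker a disjoint set of files. A few things worth knowing: it helps to have at least as many files as workers, since extra workers get no files; `DataLoader(..., shuffle=True)` is not allowed with an IterableDataset, so shuffling has to happen inside `__iter__` (e.g. shuffling the file order each epoch plus a small in-memory shuffle buffer); and under the `spawn` start method (the default on Windows and macOS) `read_records` and `decode` must be picklable, i.e. module-level functions rather than lambdas. For option 3, the simplest route I know of is to iterate a batched tf.data pipeline in the main process with `.as_numpy_iterator()` and convert each array with `torch.from_numpy`, which sidesteps worker processes entirely but keeps all decoding in TF.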

FWIW, I am aware of vahidk/tfrecord on GitHub (“TFRecord reader for PyTorch”), but I can’t get it to work with my ZLIB-compressed TFRecord features.
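On the ZLIB part, it matters where the compression lives. If the files were written with `compression_type="ZLIB"` (an option `tf.data.TFRecordDataset` and `tf.io.TFRecordWriter` accept), then as far as I know the whole file is a single zlib stream, and a reader only needs to inflate it on the fly before parsing the record framing: an 8-byte little-endian length, a 4-byte CRC of the length, the payload, and a 4-byte CRC of the payload. If instead the individual features are compressed (I believe that is what tfds’s `Encoding.ZLIB` does for tensor features), the file framing is plain and each feature’s bytes need decompressing separately. A stdlib-only sketch of the first case; `ZlibStream` and `iter_records` are hypothetical names, and CRCs are read but not verified.

```python
import struct
import zlib
from typing import BinaryIO, Iterator


class ZlibStream:
    """Minimal read-only file-like wrapper that inflates a zlib stream on the fly."""

    def __init__(self, raw: BinaryIO, chunk_size: int = 1 << 20) -> None:
        self.raw = raw
        self.chunk_size = chunk_size
        self.inflater = zlib.decompressobj()  # default wbits=15 expects a zlib header
        self.buffer = bytearray()

    def read(self, n: int) -> bytes:
        """Return up to n decompressed bytes (fewer only at end of stream)."""
        while len(self.buffer) < n and not self.inflater.eof:
            chunk = self.raw.read(self.chunk_size)
            if not chunk:
                break  # input exhausted before the end of the zlib stream
            self.buffer += self.inflater.decompress(chunk)
        out = bytes(self.buffer[:n])
        del self.buffer[:n]
        return out


def iter_records(f) -> Iterator[bytes]:
    """Yield raw serialized records from any object with .read(n), plain or inflated."""
    while True:
        header = f.read(12)                 # uint64 length + uint32 masked CRC of length
        if not header:
            return                          # clean end of stream
        if len(header) < 12:
            raise EOFError("truncated record header")
        (length,) = struct.unpack("<Q", header[:8])
        body = f.read(length + 4)           # payload + uint32 masked CRC of payload
        if len(body) < length + 4:
            raise EOFError("truncated record payload")
        yield body[:length]
```

Usage would be something like `iter_records(ZlibStream(open(path, "rb")))` for file-level ZLIB, or `iter_records(open(path, "rb"))` for uncompressed files. For the per-feature case, the equivalent step would be roughly `np.frombuffer(zlib.decompress(feature_bytes), dtype=...).reshape(shape)`, with dtype and shape taken from the dataset’s feature spec (placeholders here; worth checking against your tfds version).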