Read dataset from TFRecord format

Answering my own question:

One needs TensorFlow installed to read TFRecords (yuck! I had hoped to avoid this). The reason is that, in addition to the Protocol Buffer payload, every record in the file is framed with a length header and CRC checksums. See there.
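To make the framing concrete, here is a rough sketch of the record layout as I understand it from the TensorFlow documentation: a little-endian uint64 length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. All function names below are made up for illustration. The sketch only handles the framing; it does not decode the protobuf payload, which is where TensorFlow (or its protobuf definitions) still comes in.

```python
import io
import struct

_MASK_DELTA = 0xA282EAD8  # constant used by TensorFlow's CRC masking


def _make_crc32c_table():
    # Table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78.
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    # Rotate right by 15 bits, then add the mask delta (all mod 2**32).
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + _MASK_DELTA) & 0xFFFFFFFF


def write_record(stream, payload: bytes) -> None:
    header = struct.pack("<Q", len(payload))
    stream.write(header)
    stream.write(struct.pack("<I", masked_crc(header)))
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc(payload)))


def _read_exact(stream, n: int) -> bytes:
    data = stream.read(n)
    if len(data) != n:
        raise ValueError("truncated record")
    return data


def read_records(stream):
    """Yield raw payload bytes, checking both checksums of every record."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of stream
        if len(header) != 8:
            raise ValueError("truncated length header")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", _read_exact(stream, 4))
        if length_crc != masked_crc(header):
            raise ValueError("length checksum mismatch")
        payload = _read_exact(stream, length)
        (payload_crc,) = struct.unpack("<I", _read_exact(stream, 4))
        if payload_crc != masked_crc(payload):
            raise ValueError("payload checksum mismatch")
        yield payload


# Round trip through an in-memory buffer instead of a real .tfrecords file.
buffer = io.BytesIO()
for payload in (b"first", b"", b"third record"):
    write_record(buffer, payload)
buffer.seek(0)
records = list(read_records(buffer))
```

The pure-Python CRC is slow and only meant to show the layout; for real data the TensorFlow reader remains the practical choice.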

To demonstrate how to read and write TFRecords, I put a tiny project here - check it out.

It should now be easy to import TFRecord data into PyTorch by just wrapping the arrays in torch.Tensor objects.

@ptrblck I do not want to transform from TFRecord. I am planning to use the data as-is. The reason is that TFRecord I/O supports cloud storage out of the box. For example, these are valid filenames: gs://mybucket/training/data/blah.tfrecords, or s3://mybucket/training/data/foo.tfrecords.

This is very convenient, as my training process uses standard disposable cloud workers that should not store anything of value on their local drives! So the plan is:

  • install both tf and torch
  • read data from TFRecord into torch.Tensor
  • hack torch.utils.data.DataLoader to be able to read streaming data (no len!!!). And throw all the existing Torch Dataset machinery under the bus - it is based on a random-access model, alas.
  • prove that this is reasonably fast and is not a bottleneck for training
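
The streaming part of the plan can be sketched without any framework: batch a one-pass iterator that has no length and no random access. Everything here is illustrative; `stream_batches` and `fake_record_stream` are hypothetical names, and the fake stream stands in for whatever actually yields decoded TFRecord examples. Newer PyTorch releases also ship `torch.utils.data.IterableDataset`, which targets this same streaming model.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def stream_batches(records: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group a one-pass stream into lists of up to batch_size items.

    Never calls len() and never indexes into the source, so it works
    with generators that read from remote storage.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    iterator = iter(records)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return  # stream exhausted
        yield batch


def fake_record_stream(n: int):
    # Hypothetical stand-in for a TFRecord reader yielding decoded examples.
    for i in range(n):
        yield {"x": [float(i), float(i) + 0.5], "y": i % 2}


batches = list(stream_batches(fake_record_stream(5), batch_size=2))
```

The last batch may be shorter than `batch_size`; a training loop has to either accept that or drop it.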

Thanks for reading!
