Hello PyTorch community,
I am exploring ways to speed up the loading process for a serialized tensor file in PyTorch.
Here’s my scenario:
I have a single large tensor that I previously saved with `torch.save()` into a file on local disk, say `tensor.pt`. Now I want to reduce the time it takes to load this tensor back with `torch.load("tensor.pt")`. As I understand it, `torch.load` uses Python's unpickling mechanism, which is single-threaded. I'm wondering whether there are ways to speed up this loading process.
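For reference, here is a minimal sketch of the baseline round trip described above. The tensor size, the temporary path, and timing with `time.perf_counter()` are illustrative assumptions, not numbers from my actual setup.

```python
import os
import tempfile
import time

import torch

# Illustrative size: 4 Mi float32 elements (~16 MiB). The real tensor is much larger.
tensor = torch.randn(4 * 1024 * 1024)

# Hypothetical location standing in for the tensor.pt file on local disk.
path = os.path.join(tempfile.mkdtemp(), "tensor.pt")
torch.save(tensor, path)

# Time only the load step, which is the part I want to speed up.
start = time.perf_counter()
loaded = torch.load(path)
elapsed = time.perf_counter() - start

print(f"loaded {loaded.numel()} elements in {elapsed:.3f} s")
```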
Some ideas I’ve considered:
- Splitting the tensor files: Save the one large tensor into multiple files, each containing a chunk of the tensor. Then, use multithreading to load these files in parallel and reconstruct the original tensor.
- Alternative formats: Serialize the tensor via a format such as a pandas DataFrame or a NumPy ndarray, and look for multithreaded deserialization options. After loading, convert the result to a PyTorch tensor without copying, if possible.
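The splitting idea could be sketched roughly as follows, assuming the tensor is split along dimension 0 with `torch.chunk` and the chunks are loaded with a `ThreadPoolExecutor`. `save_chunks` and `load_chunks` are hypothetical helper names. Whether the threads actually overlap depends on how much of `torch.load` runs without holding Python's GIL, so this is not guaranteed to be faster than a single `torch.load`.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

import torch


def save_chunks(tensor: torch.Tensor, directory: str, num_chunks: int) -> list:
    """Hypothetical helper: split `tensor` along dim 0 and save each piece to its own file."""
    paths = []
    for i, chunk in enumerate(torch.chunk(tensor, num_chunks, dim=0)):
        path = os.path.join(directory, f"chunk_{i}.pt")
        # torch.chunk returns views; .clone() makes each file hold only its own
        # slice instead of the whole storage the view shares with `tensor`.
        torch.save(chunk.clone(), path)
        paths.append(path)
    return paths


def load_chunks(paths: list, num_workers: int) -> torch.Tensor:
    """Hypothetical helper: load the chunk files concurrently and rebuild the tensor."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # pool.map returns results in input order, so the chunks stay in sequence.
        chunks = list(pool.map(torch.load, paths))
    # torch.cat allocates a new tensor and copies every chunk into it.
    return torch.cat(chunks, dim=0)


# Illustrative data standing in for the real large tensor.
original = torch.randn(1_000_000, 8)
with tempfile.TemporaryDirectory() as tmp:
    paths = save_chunks(original, tmp, num_chunks=4)
    restored = load_chunks(paths, num_workers=4)
```

Note that the final `torch.cat` adds one full in-memory copy on top of the file reads, which eats into any gain from loading in parallel.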
So far, I haven't found an elegant solution to this problem. I'm open to suggestions beyond `.pt` files, since my goal is more general: optimizing the loading of serialized binary tensor files. The `.pt` file is just the simplest example of a tensor file.
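As one example of a non-`.pt` binary format, a sketch of the NumPy route might look like this, assuming the `.npy` format written by `np.save`. The file is memory-mapped with `np.load(..., mmap_mode="c")` and wrapped with `torch.from_numpy`, which shares memory with the array instead of copying it. With memory mapping, the disk reads are deferred until pages are first touched, so the cost is moved to first access rather than removed.

```python
import os
import tempfile

import numpy as np
import torch

# Illustrative data; any contiguous CPU tensor with a NumPy-compatible dtype works.
original = torch.randn(1_000_000, 8)

# Hypothetical file location for the example.
path = os.path.join(tempfile.mkdtemp(), "tensor.npy")

# .npy is a small header followed by the raw array bytes; no pickling of the data.
np.save(path, original.numpy())

# mmap_mode="c" maps the file copy-on-write: pages are read lazily on access,
# and writes to the array stay in memory instead of modifying the file.
array = np.load(path, mmap_mode="c")

# torch.from_numpy wraps the same buffer, so no copy is made at this step.
restored = torch.from_numpy(array)
```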
Additional context:
- My primary operation is performing mathematical computations on the tensor once it's loaded. I don't think `Dataset` or `DataLoader` fits this use case, as this is not about shuffling, sampling, or batching; it's just about efficiently loading one large tensor.
Any hints, advice, or alternative approaches would be greatly appreciated!