How to Save DataLoader?

Hi,

I am new to PyTorch and am currently experimenting with PyTorch’s DataLoader on Google Colab. My experiments often require training times over 12 hours, which is more than a single Google Colab session offers. For this reason, I need to be able to save my optimizer, learning rate scheduler, and the state at specific epoch checkpoints (e.g., every epoch that is a multiple of 5).

I also need to save the data loaded per mini-batch, whose size is either 32 or 64. However, I could not find a way to save the current state of the DataLoader so it can be reused in later training epochs. How can I save the current state of a DataLoader? Thanks in advance.

Note: this may be related to this issue on PyTorch’s GitHub.


Can’t we use a random seed and get the same list? Then knowing the index would suffice.

Maybe something like this:

import random

from torch.utils.data import DataLoader, Sampler

class MySampler(Sampler):
    """Sampler whose order can be shuffled in place, saved, and restored."""

    def __init__(self, data_source):
        self.seq = list(range(len(data_source)))

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

dataset = LoadYourDataset()  # placeholder: replace with your own Dataset
sampler = MySampler(dataset)
# shuffle must stay False when a custom sampler is supplied
dataloader = DataLoader(dataset, shuffle=False, sampler=sampler)

for epoch in range(999):
    # shuffle the index order in place at the start of each epoch
    random.shuffle(dataloader.sampler.seq)
    for i, (x, y) in enumerate(dataloader):
        # training step goes here;
        # save i and dataloader.sampler.seq so the epoch can be resumed
        ...

I cannot promise that this code will work (just an idea).
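To actually resume mid-epoch, a possible sketch building on that idea (assuming i and seq were stored with torch.save under the hypothetical name 'sampler_state.pt'; the itertools.islice skip is my own addition and still iterates over, and discards, the skipped batches):

import itertools

import torch

# restore the saved order and the index of the last finished batch
state = torch.load('sampler_state.pt')  # hypothetical checkpoint file
dataloader.sampler.seq = state['seq']
start_batch = state['i'] + 1

# skip the already-processed batches, then continue from the saved position
batches = itertools.islice(dataloader, start_batch, None)
for i, (x, y) in enumerate(batches, start=start_batch):
    ...  # training step resumes here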

Thank you for the answers. After trying some code of my own yesterday, I figured out that the DataLoader can be saved directly with torch.save(dataloader_obj, 'dataloader.pth'). So far the order of the data is maintained, and the batches as well.
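For example, a minimal sketch of this approach (with a toy TensorDataset standing in for the real dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

# toy dataset standing in for the real one
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# pickle the whole DataLoader object to disk
torch.save(dataloader, 'dataloader.pth')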


How did you load it after saving? Can you provide sample code, please?


Oh, just use torch.load as usual, if I’m not mistaken.
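Something like this, I think (note for readers on newer PyTorch releases: torch.load may there need weights_only=False, since it no longer unpickles arbitrary objects by default):

import torch

# restore the pickled DataLoader; no need to construct a new one first
dataloader = torch.load('dataloader.pth')

for x, y in dataloader:
    ...  # batches come back in the stored order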


Thanks. I tried creating a DataLoader() object first and then loading into it, and got errors. It seems torch.load() alone is enough.

Thanks

Actually, I can’t torch.save a DataLoader. Is it definitely working on your side? If so, which PyTorch version are you using?

It worked on PyTorch 0.4.0, if I’m not mistaken.

Hello,
I’m on torch 1.2.0, and torch.save seems to work. However, when I try to load my DataLoader (with torch.load), in which I included transforms from torchvision, I get the following error:

 File "C:\Users\USER\env\lib\site-packages\torch\serialization.py", line 386, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\USER\env\lib\site-packages\torch\serialization.py", line 573, in _load
    result = unpickler.load()
ModuleNotFoundError: No module named 'transforms'

Any idea how to solve this?

Have you tried importing the torchvision.transforms module? Also, could you verify that torchvision is installed, @Fanny?

Yes @Noiran_Allen, the module is imported in my script, and when I use transforms, for example with tr = transforms.Resize(248), it does work.

I think I found the problem. In the topic ModuleNotFoundError: No module named 'network' - #3 by Oscar_Rangel, concerning the saving of full models, ptrblck said:

If you are using this approach: model = torch.load(path) , you would need to make sure that all necessary files are in the corresponding folders as they were while storing the model.

It looks like the same applies here: loading works when my project has the same layout as it did at the time I saved the DataLoader.
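For anyone else hitting this, a sketch of the workaround under that assumption (the project path is hypothetical): put the directory that held the original transforms module back on sys.path before loading.

import sys

# hypothetical path: the project layout used when the DataLoader was saved
sys.path.insert(0, '/path/to/original/project')

import torch

dataloader = torch.load('dataloader.pth')  # unpickling can now find 'transforms'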

As a side remark: apparently it is recommended to use the extension “.pt” rather than “.pth” when saving model checkpoints (and maybe, for the same reason, also data loaders), since “.pth” is also used by Python for path configuration files.
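For completeness, a sketch of the checkpointing style the original question asked about, saving state_dicts instead of pickling whole objects (model, optimizer, scheduler, and epoch are placeholders from the surrounding training script):

import torch

# gather the states to be checkpointed
checkpoint = {
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'scheduler_state': scheduler.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pt')  # ".pt" rather than ".pth"

# restoring later:
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
scheduler.load_state_dict(checkpoint['scheduler_state'])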