I am new to PyTorch and am currently experimenting with PyTorch’s DataLoader on Google Colab. My experiments often need more than 12 hours of training, which is more than a Google Colab session offers. For this reason, I need to be able to save my optimizer, learning rate scheduler, and training state at specific epoch checkpoints (e.g., every 5 epochs).
I also need to save the data loaded in each mini-batch, where the batch size is either 32 or 64. However, I have not found a way to save the current state of the DataLoader so that it can be reused in later training epochs. How can I save the current state of a DataLoader? Thanks in advance.
Note: This may be related to this issue on PyTorch’s GitHub.
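For the model, optimizer, and scheduler part of the question, the usual pattern is to save each object's state_dict() in one checkpoint dict and load it back into freshly constructed objects after a restart. Below is a minimal sketch, assuming a toy nn.Linear model, SGD with momentum, and StepLR as hypothetical stand-ins for your own objects; the helper names save_checkpoint, load_checkpoint, and make_training_objects are also hypothetical. On Colab the checkpoint path would need to be somewhere persistent (e.g. a mounted Google Drive); the temp directory here is only for illustration. This does not save the DataLoader's position within an epoch, which is what the rest of the thread discusses.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, epoch, model, optimizer, scheduler):
    # state_dict() holds the model weights, the optimizer's internal buffers
    # (e.g. momentum) and the scheduler's step counter / current learning rate.
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, path)


def load_checkpoint(path, model, optimizer, scheduler):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["epoch"] + 1  # first epoch that still has to be trained


def make_training_objects():
    # Hypothetical toy setup; replace with your own model, optimizer, scheduler.
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    return model, optimizer, scheduler


ckpt_path = os.path.join(tempfile.gettempdir(), "checkpoint.pt")
model, optimizer, scheduler = make_training_objects()
for epoch in range(10):
    # Stand-in for one real training epoch.
    loss = model(torch.randn(8, 3)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    if (epoch + 1) % 5 == 0:  # checkpoint every 5 epochs
        save_checkpoint(ckpt_path, epoch, model, optimizer, scheduler)

# In a new Colab session: rebuild the objects, then restore their saved state.
model2, optimizer2, scheduler2 = make_training_objects()
start_epoch = load_checkpoint(ckpt_path, model2, optimizer2, scheduler2)
```

Training would then continue with for epoch in range(start_epoch, num_epochs). The objects must be constructed the same way before loading, since a state_dict only stores values, not the object structure.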
```python
import random
from torch.utils.data import Sampler, DataLoader

class MySampler(Sampler):
    def __init__(self, data_source):
        # Keep an explicit, mutable index order that can be saved and restored.
        self.seq = list(range(len(data_source)))

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

dataset = LoadYourDataset()  # placeholder: replace with your own Dataset
sampler = MySampler(dataset)
dataloader = DataLoader(dataset, shuffle=False, sampler=sampler)

for epoch in range(0, 999):
    random.shuffle(dataloader.sampler.seq)  # new order for this epoch
    for i, (x, y) in enumerate(dataloader):
        # some code.
        # save i and dataloader.sampler.seq
        pass
```
I cannot promise that this code will work (just an idea).
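A variation on this idea, sketched under some assumptions: instead of saving the whole shuffled index list, derive each epoch's order from a seed and the epoch number, and save only those two integers plus how many samples were already consumed. The class ResumableSampler and its state_dict/load_state_dict methods are hypothetical names that only mimic PyTorch's naming convention; they are not a torch API for samplers. Recent PyTorch versions accept any iterable of indices as the sampler argument of DataLoader; if yours does not, subclass torch.utils.data.Sampler as in the code above.

```python
import random


class ResumableSampler:
    """Hypothetical sampler: shuffles per epoch from (seed, epoch) and can
    skip the samples already consumed before a checkpoint."""

    def __init__(self, num_samples, seed=0):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0
        self.start_index = 0  # samples to skip at the start of this epoch

    def _permutation(self):
        # The same (seed, epoch) pair always gives the same order,
        # so the full index list never has to be saved.
        rng = random.Random(self.seed * 1_000_003 + self.epoch)
        perm = list(range(self.num_samples))
        rng.shuffle(perm)
        return perm

    def set_epoch(self, epoch, start_index=0):
        self.epoch = epoch
        self.start_index = start_index

    def __iter__(self):
        return iter(self._permutation()[self.start_index:])

    def __len__(self):
        return max(0, self.num_samples - self.start_index)

    def state_dict(self, batches_done, batch_size):
        # batches_done should be counted by the training loop: DataLoader
        # workers may prefetch indices ahead of the batches actually processed.
        consumed = self.start_index + batches_done * batch_size
        return {
            "seed": self.seed,
            "epoch": self.epoch,
            "start_index": min(self.num_samples, consumed),
        }

    def load_state_dict(self, state):
        self.seed = state["seed"]
        self.set_epoch(state["epoch"], state["start_index"])
```

In a training loop one would call sampler.set_epoch(epoch) at the start of each epoch, build the loader as DataLoader(dataset, batch_size=32, sampler=sampler), store sampler.state_dict(i + 1, 32) in the checkpoint after batch i, and call sampler.load_state_dict(...) after a restart. With shuffle=False and a fixed batch size, the resumed loader then yields the remaining batches of the interrupted epoch.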
Thank you for the answers. After trying some code of my own yesterday, I found that a DataLoader can be saved directly with PyTorch’s torch.save(dataloader_obj, 'dataloader.pth'). So far, both the order of the data and the batches are preserved.
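Pickling the whole DataLoader also pickles its dataset. A lighter alternative, sketched here assuming a PyTorch version whose DataLoader accepts a generator argument (older releases may not), is to checkpoint only the state of the torch.Generator that drives the shuffling. The toy TensorDataset stands in for a real dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for the real one (hypothetical).
dataset = TensorDataset(torch.arange(10))

g = torch.Generator()
g.manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)

# Save this (a small uint8 tensor) in the checkpoint before each epoch starts.
epoch_rng_state = g.get_state()
first_pass = [batch[0].tolist() for batch in loader]

# After a restart: restore the state and the same epoch replays in the same order.
g.set_state(epoch_rng_state)
second_pass = [batch[0].tolist() for batch in loader]
```

This reproduces the order of a whole epoch; to resume in the middle of one, the batches already processed would still need to be skipped, for example with a custom sampler as suggested earlier in the thread.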
Hello,
I’m on torch 1.2.0, and torch.save seems to work. However, when I use torch.load to load my DataLoader, which includes transforms from torchvision, I get the following error:
```
  File "C:\Users\USER\env\lib\site-packages\torch\serialization.py", line 386, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\USER\env\lib\site-packages\torch\serialization.py", line 573, in _load
    result = unpickler.load()
ModuleNotFoundError: No module named 'transforms'
```
If you are using this approach (model = torch.load(path)), you need to make sure that all necessary files are in the same folders as they were when you stored the model.
It looks like this applies here too: loading works when I use the same folder organisation as when I saved the DataLoader.
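The reason is that torch.save pickles objects by default, and pickle records a class as a module path plus a class name, not as its source code. torch.load therefore has to import the same module again, presumably a module that was importable as transforms when the DataLoader was saved. A minimal pure-Python sketch reproducing the error, using a hypothetical module my_transforms written to a temporary folder as a stand-in for that file:

```python
import importlib
import os
import pickle
import sys
import tempfile


def pickle_then_lose_module():
    with tempfile.TemporaryDirectory() as tmp:
        # Hypothetical local module standing in for a project file like transforms.py.
        with open(os.path.join(tmp, "my_transforms.py"), "w") as f:
            f.write("class Scale:\n    def __init__(self, k):\n        self.k = k\n")
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        import my_transforms

        blob = pickle.dumps(my_transforms.Scale(2))
        # Simulate loading from a different folder layout: the module is gone.
        sys.path.remove(tmp)
        del sys.modules["my_transforms"]

    # The pickle only references "my_transforms" and "Scale" by name.
    assert b"my_transforms" in blob
    try:
        pickle.loads(blob)
    except ModuleNotFoundError as err:
        return str(err)
    return "loaded"
```

Saving only state (e.g. indices, seeds, state_dicts) and rebuilding the DataLoader in code avoids this dependency on the folder layout at load time.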
As a side remark: apparently, it is recommended to use the extension “.pt” rather than “.pth” when saving model checkpoints (and perhaps, for the same reason, data loaders as well):