In the context of creating my own Dataset to feed into a PyTorch DataLoader, I have designed a way to inherit from a class programmatically. Basically, I extend a class that is going to be used as a Dataset in order to add 'custom' functionality to it. The dynamic extension works nicely. However, PyTorch doesn't like it: when I start iterating over a DataLoader built on it, it complains.
Here is a toy example of the extended class:
# utils.py: mock dataset. For some reason, this has to be in a separate file
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second
# main script
import pickle
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils import MyDataset


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


if __name__ == '__main__':
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)
    b = extended_class()
    b.hello()  # this works!
    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
    first, second = next(iterator)
Any workaround to do this is appreciated!
The code snippet works on my machine. If I add print(first, second) as the last line of code, I get:
Yo!
tensor([1, 1, 1, 1]) tensor([2, 2, 2, 2])
What version of pickle do you have? pickle.format_version returns 4.0. I'm also using PyTorch 1.6.0.dev20200611.
Weird. I have the same pickle version, and I just updated to PyTorch 1.7.0.dev20200626. I still get the same problem:
Traceback (most recent call last):
  File "C:\Users\valer\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-2cd705ece2e1>", line 29, in <module>
    iterator = iter(dataloader) # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
  File "C:\Users\valer\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 282, in __iter__
    return _MultiProcessingDataLoaderIter(self)
  File "C:\Users\valer\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 728, in __init__
    w.start()
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\process.py", line 112, in start
    self._popen = self._Popen(self)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 89, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'extend_class.<locals>.B'

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\spawn.py", line 115, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input
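The first traceback ends in ForkingPickler(file, protocol).dump(obj): with num_workers > 0 the DataLoader starts worker processes and, under the spawn start method used on Windows, sends each of them a pickled copy of the dataset. Pickle stores a class by its module and qualified name and looks it up again when loading, and a class created inside a function (qualified name extend_class.<locals>.B) cannot be found that way. A minimal sketch that reproduces this with pickle alone, no PyTorch involved; Base and is_picklable are made-up names for illustration:

```python
import pickle


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


class Base:
    """Plain stand-in for MyDataset; any module-level class behaves the same."""


def is_picklable(obj):
    """Hypothetical helper: return True if pickle can serialize obj."""
    try:
        pickle.dumps(obj)
    except (AttributeError, TypeError, pickle.PicklingError):
        # Python 3.7 raises AttributeError("Can't pickle local object ...")
        # for local classes; other versions may raise PicklingError instead.
        return False
    return True


if __name__ == '__main__':
    B = extend_class(Base)
    print(B.__qualname__)        # extend_class.<locals>.B
    print(is_picklable(Base()))  # True: found again as <module>.Base
    print(is_picklable(B()))     # False: the name cannot be looked up
```

The second traceback (EOFError: Ran out of input) is most likely a knock-on effect: on Windows the worker process is created first, and it finds nothing to read once pickling fails in the parent.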
I also tried it on PyTorch 1.5.0, with the same result.
I wonder if anyone else could kindly try the code snippet and report back. Thanks!
This works on my server machine, which has a GPU, runs PyTorch 1.4.0 and runs Linux (my local machine runs Windows). Does it have to do with multiprocessing and/or the OS? I have tried the following version of the code, but it doesn't work.
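The OS difference is most likely explained by the multiprocessing start method rather than by the GPU. Windows only supports 'spawn', which pickles everything handed to a worker process, the dataset included. On Linux the default has long been 'fork', where the worker inherits a copy of the parent's memory, so the dataset is never pickled and the local class never has to be looked up by name. Defaults can differ across Python versions and platforms, so this sketch just reports what applies on a given machine; dataset_is_pickled is a made-up helper name:

```python
import multiprocessing as mp

# Start methods that hand Process arguments (for a DataLoader: the dataset)
# to the child by pickling them. With 'fork' the child instead inherits a
# copy of the parent's memory, so nothing is pickled.
PICKLING_START_METHODS = {'spawn', 'forkserver'}


def dataset_is_pickled(start_method):
    """Hypothetical helper: True if workers started this way get a pickled dataset."""
    return start_method in PICKLING_START_METHODS


if __name__ == '__main__':
    method = mp.get_start_method()
    print('default start method:', method)
    print('available on this OS:', mp.get_all_start_methods())
    print('dataset pickled for workers:', dataset_is_pickled(method))
```

The DataLoader also accepts a multiprocessing_context argument for choosing the start method explicitly, but on Windows 'spawn' is the only choice.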
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils import MyDataset
from multiprocessing import freeze_support


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


def do_stuff():
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)
    b = extended_class()
    b.hello()  # this works!
    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
    first, second = next(iterator)


if __name__ == '__main__':
    freeze_support()
    do_stuff()
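freeze_support() only has an effect when the program is frozen into a Windows executable, so it does not change how the dataset is pickled. One possible workaround, assuming the set of extensions is known when the module is imported, is to give the dynamically created class a real module-level name, since pickle only needs module.<name> to resolve back to the same class. The name argument and ExtendedMyDataset are made up for this sketch, and MyDataset is a plain stand-in so it runs without PyTorch. For spawned workers this code would have to live in an importable module such as utils.py, not in an interactive session.

```python
import pickle


class MyDataset:
    """Stand-in for utils.MyDataset so the sketch runs without PyTorch."""

    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second


def extend_class(base_class, name):
    class B(base_class):
        def hello(self):
            print('Yo!')

    # Pickle stores a class as (module, qualified name) and checks that the
    # name resolves back to the same object, so give it a top-level name...
    B.__name__ = B.__qualname__ = name
    # ...and make that name actually exist in this module.
    globals()[name] = B
    return B


# Created at import time, so a spawned worker that imports this module
# defines the same name before it unpickles the dataset.
ExtendedMyDataset = extend_class(MyDataset, 'ExtendedMyDataset')


if __name__ == '__main__':
    b = ExtendedMyDataset()
    restored = pickle.loads(pickle.dumps(b))
    restored.hello()  # prints Yo!
```

The catch is that extend_class must run when the module is imported; a class created only inside the if __name__ == '__main__': block would not exist in a spawned worker.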
I've tested it on my Windows laptop, so it doesn't seem to be OS-specific.
It might be related to multiprocessing somehow, but I don't know how or why it would run on my system.
Also, I'm using Python 3.7, in case that matters.
Thanks for your reply. I am also using Python 3.7. Are you running it on a system with a GPU? Could you try to momentarily disable the GPU? Maybe it depends on the PyTorch version installed on my system (CPU-only, obviously)?
An update: this works even on my Windows machine without a GPU if I set num_workers=0.
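num_workers=0 works because batches are then loaded in the main process, so the dataset is never pickled; the price is losing parallel data loading. A sketch that keeps num_workers > 0 and the fully dynamic extension, assuming the base class itself is picklable (defined at module level): the extended class implements __reduce__, so pickle stores "call a module-level rebuild function" instead of a reference to the unreachable name extend_class.<locals>.B. The helper _rebuild is a made-up name, functools.lru_cache makes repeated calls with the same base return the same class object, and MyDataset is a plain stand-in so the sketch runs without PyTorch.

```python
import functools
import pickle


class MyDataset:
    """Stand-in for utils.MyDataset so the sketch runs without PyTorch."""

    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second


def _rebuild(base_class, state):
    """Hypothetical module-level helper that pickle can find by name."""
    cls = extend_class(base_class)  # cached, so the same class object comes back
    obj = cls.__new__(cls)          # skip __init__; the state is restored below
    obj.__dict__.update(state)
    return obj


@functools.lru_cache(maxsize=None)
def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')

        def __reduce__(self):
            # Pickle as "call _rebuild(base_class, state)" instead of by
            # reference to the unreachable name 'extend_class.<locals>.B'.
            return _rebuild, (base_class, self.__dict__)

    return B


if __name__ == '__main__':
    b = extend_class(MyDataset)()
    restored = pickle.loads(pickle.dumps(b))
    restored.hello()    # prints Yo!
    print(restored[0])  # prints (1, 2)
```

When running from IPython (as the traceback's <ipython-input-...> frame suggests), a spawned worker cannot import code defined in the interactive session, which is likely why MyDataset has to live in utils.py; extend_class and _rebuild would then need to move there as well.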