Programmatic extension for Custom Dataset error

While creating my own Dataset to feed into a PyTorch DataLoader,
I came up with a way to inherit from a class programmatically, i.e. to extend a class that is going to be used as a Dataset in order to add ‘custom’ functionality to it. The dynamic extension itself works nicely. However, PyTorch doesn’t like it: as soon as I start iterating over a DataLoader built on it, it raises an error.

Here is a toy example for the extended class:

# utils.py: mock dataset. This has to be in a different file for some reason
from torch.utils.data import Dataset
class MyDataset(Dataset):
    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second


# Main script (a separate file from utils.py)
import pickle
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils import MyDataset


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')

    return B


if __name__ == '__main__':
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)

    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)

    b = extended_class()
    b.hello() # this works!

    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)

    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'

    first, second = next(iterator)

Any workaround to do this is appreciated!
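
One possible workaround, sketched under assumptions: when DataLoader workers are started with the spawn method, the dataset is pickled and sent to each worker, and pickle cannot look up by name a class that was defined inside a function. Giving the dynamic class a __reduce__ that points to a module-level helper sidesteps that lookup. The helper name _rebuild is made up for this sketch, and the plain MyDataset below is a torch-free stand-in for the one in utils.py so the snippet runs on its own. I have not verified it against a real DataLoader on Windows. The simplest alternative is num_workers=0, which loads data in the main process and starts no workers.

```python
import pickle


class MyDataset:
    """Torch-free stand-in for the MyDataset in utils.py (assumption for this sketch)."""

    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second


def _rebuild(base_class, state):
    # Hypothetical helper: runs in the process that unpickles the object.
    # It recreates the dynamic class there and restores the instance attributes.
    cls = extend_class(base_class)
    obj = cls.__new__(cls)
    obj.__dict__.update(state)
    return obj


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')

        def __reduce__(self):
            # Pickle a recipe (module-level function + arguments) instead of
            # a reference to the local class B, which pickle cannot look up.
            return _rebuild, (base_class, self.__dict__)

    return B


b = extend_class(MyDataset)()
restored = pickle.loads(pickle.dumps(b))  # same round trip a spawned worker needs
restored.hello()
print(restored[0], len(restored))
```

Note that every process ends up with its own freshly created B, so type(restored) is not the same class object as type(b), and isinstance checks against the original dynamic class will not hold across processes.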

The code snippet works on my machine.
If I add print(first, second) as the last line of code, I get:

Yo!
tensor([1, 1, 1, 1]) tensor([2, 2, 2, 2])

What version of pickle do you have?

pickle.format_version returns 4.0. I’m also using PyTorch 1.6.0.dev20200611.

Weird. I have the same pickle version, and I just updated PyTorch to 1.7.0.dev20200626.
I still get the same problem:

Traceback (most recent call last):
  File "C:\Users\valer\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-2cd705ece2e1>", line 29, in <module>
    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
  File "C:\Users\valer\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 282, in __iter__
    return _MultiProcessingDataLoaderIter(self)
  File "C:\Users\valer\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 728, in __init__
    w.start()
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\process.py", line 112, in start
    self._popen = self._Popen(self)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 89, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'extend_class.<locals>.B'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "C:\Users\valer\Anaconda3\lib\multiprocessing\spawn.py", line 115, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input

I also tried it on PyTorch 1.5.0, same result.
I wonder if anyone else can kindly try the code snippet and report. Thanks
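
To narrow it down: the first traceback ends in multiprocessing's reduction.dump, which pickles the process object (dataset included) for the child process. The same failure should therefore show up with pickle alone, without PyTorch or a DataLoader. A small check, where Base is a made-up stand-in for MyDataset:

```python
import pickle


class Base:
    """Stand-in for MyDataset (assumption: any module-level class behaves the same here)."""


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')

    return B


b = extend_class(Base)()

try:
    pickle.dumps(b)
    pickled_ok = True
except (AttributeError, pickle.PicklingError) as exc:
    # The exact exception type depends on the Python version.
    pickled_ok = False
    print(type(exc).__name__, exc)

print('pickled_ok =', pickled_ok)
```

If this fails on a machine where the DataLoader version runs, that machine is probably not pickling the dataset when it starts its workers.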

This works on my server machine, which has a GPU, runs Linux, and has PyTorch 1.4.0 (my local machine runs Windows). Does it have to do with multiprocessing and/or with the OS? I have also tried this version of the code, but it doesn’t work:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils import MyDataset
from multiprocessing.dummy import freeze_support


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')

    return B


def do_stuff():
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)

    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)

    b = extended_class()
    b.hello() # this works!

    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)

    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'

    first, second = next(iterator)

if __name__ == '__main__':
    freeze_support()
    do_stuff()

I’ve tested it on my Windows laptop, so it doesn’t seem to be OS-specific.
It might be related to multiprocessing somehow, but I don’t know how and why it would run on my system.
Also, I’m using Python 3.7 in case that matters.
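
One thing that may be worth comparing between machines is the multiprocessing start method. With spawn (the only option on Windows, and the one visible in the traceback via popen_spawn_win32) the dataset has to be pickled for each worker. With fork (the usual default on Linux for these Python versions) the worker inherits the parent's memory and the dataset is not pickled, which would fit the Linux server working. It does not by itself explain a Windows machine where the snippet runs. A quick check using only the standard library:

```python
import multiprocessing as mp
import sys

# Start method this interpreter uses by default for new processes.
method = mp.get_start_method()
# All start methods supported on this platform.
available = mp.get_all_start_methods()

print(sys.platform, method, available)
```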

Thanks for your reply. I am also using Python 3.7. Are you running it on a system with a GPU? Could you try to momentarily disable the GPU? Maybe it depends on the PyTorch version I have installed on my system? (CPU-only version, obviously)