In the context of creating my own Dataset to feed into a PyTorch DataLoader, I have designed a way to inherit from a class programmatically. In other words, I extend the class that is going to be used as the Dataset in order to add custom functionality to it. The dynamic extension itself works nicely. However, PyTorch doesn't like it: as soon as I start iterating over a DataLoader built on an instance of the extended class, it raises an `AttributeError` saying it can't pickle the local class.
Here is a toy example, split across two files:
```python
# utils.py
# Mock dataset. This has to be in a different file for some reason.
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second
```
```python
# main script
import pickle

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from utils import MyDataset


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


if __name__ == '__main__':
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)
    b = extended_class()
    b.hello()  # this works!

    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)
    # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
    iterator = iter(dataloader)
    first, second = next(iterator)
```
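For context, the same error can be reproduced without PyTorch. `pickle` stores a class by reference (its module plus qualified name), and a class created inside a function gets a qualified name containing `<locals>` that cannot be looked up again. A DataLoader with `num_workers > 0` has to hand the dataset to its worker processes, and when those workers are started with the spawn method (the default on Windows and macOS) the dataset is pickled. The minimal sketch below uses a hypothetical plain `Base` class as a stand-in for `MyDataset`:

```python
import pickle


class Base:
    """Hypothetical stand-in for MyDataset (no PyTorch needed)."""
    def __init__(self):
        self.first = 1
        self.second = 2


def extend_class(base_class):
    # B is defined inside the function, so its qualified name is
    # 'extend_class.<locals>.B' and pickle cannot look it up by name later.
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


Extended = extend_class(Base)
obj = Extended()
print(Extended.__qualname__)  # extend_class.<locals>.B

try:
    # Roughly what has to happen before a spawned DataLoader worker gets the dataset.
    pickle.dumps(obj)
except (AttributeError, pickle.PicklingError) as exc:
    # AttributeError on most Python 3 versions; newer versions may raise PicklingError.
    print(type(exc).__name__, exc)
```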
Any workaround is appreciated!
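One possible workaround, sketched under assumptions: instead of trying to make the local class findable by name, give it a `__reduce__` method that tells pickle to rebuild the instance by calling a module-level helper (the name `_rebuild` is hypothetical). The helper recreates the extended class from the base class and restores the instance's `__dict__`. `functools.lru_cache` on `extend_class` keeps the same base class mapped to the same extended class within one process. `Base` is again a hypothetical stand-in for `MyDataset`; the sketch assumes the base class is defined at module level and keeps its state in `__dict__` (no `__slots__` and no custom pickling of its own).

```python
import functools
import pickle


class Base:
    """Hypothetical stand-in for MyDataset; it must live at module level."""
    def __init__(self):
        self.first = 1
        self.second = 2


def _rebuild(base_class, state):
    # Module-level helper (hypothetical name), so pickle can find it by name.
    cls = extend_class(base_class)
    obj = cls.__new__(cls)  # skip __init__; the saved state is restored below
    obj.__dict__.update(state)
    return obj


@functools.lru_cache(maxsize=None)
def extend_class(base_class):
    # Cached: the same base class always yields the same extended class
    # within one process, so type checks stay consistent.
    class B(base_class):
        def hello(self):
            print('Yo!')

        def __reduce__(self):
            # Pickle "how to rebuild me" instead of a reference to B itself.
            return (_rebuild, (base_class, self.__dict__))
    return B


b = extend_class(Base)()
b.first = 10
restored = pickle.loads(pickle.dumps(b))
restored.hello()                        # prints Yo!
print(restored.first, restored.second)  # prints 10 2
```

In the original script, `extend_class(MyDataset)()` would then be passed to the DataLoader as before, with `_rebuild` and `extend_class` defined at the top level of a module the workers can import. This sketch exercises only plain `pickle`, so it is worth confirming against your PyTorch version. Setting `num_workers=0` also sidesteps the problem, because the dataset is then used in the main process without being pickled, at the cost of losing parallel loading.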