I am trying to use num_workers with my data loader, however it fails due to a PicklingError. Here is the error:
<ipython-input-51-2646a02f621a> in train_SN(model, optimizer, scheduler, episodes)
37 targets = torch.zeros(NUM_CL,(NUM_EX+NUM_CL-1)).to(device=device, dtype=dtype)
38 sample_count = 0
---> 39 for i, (sample, sample_label) in enumerate(train_sample_loader):
40 sample_count += 1
41 idx = 0
C:\ProgramData\Miniconda3\lib\site-packages\torch\utils\data\dataloader.py in __iter__(self)
499
500 def __iter__(self):
--> 501 return _DataLoaderIter(self)
502
503 def __len__(self):
C:\ProgramData\Miniconda3\lib\site-packages\torch\utils\data\dataloader.py in __init__(self, loader)
287 for w in self.workers:
288 w.daemon = True # ensure that the worker exits on process exit
--> 289 w.start()
290
291 _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
C:\ProgramData\Miniconda3\lib\multiprocessing\process.py in start(self)
103 'daemonic processes are not allowed to have children'
104 _cleanup()
--> 105 self._popen = self._Popen(self)
106 self._sentinel = self._popen.sentinel
107 # Avoid a refcycle if the target function holds an indirect
C:\ProgramData\Miniconda3\lib\multiprocessing\context.py in _Popen(process_obj)
210 @staticmethod
211 def _Popen(process_obj):
--> 212 return _default_context.get_context().Process._Popen(process_obj)
213
214 class DefaultContext(BaseContext):
C:\ProgramData\Miniconda3\lib\multiprocessing\context.py in _Popen(process_obj)
311 def _Popen(process_obj):
312 from .popen_spawn_win32 import Popen
--> 313 return Popen(process_obj)
314
315 class SpawnContext(BaseContext):
C:\ProgramData\Miniconda3\lib\multiprocessing\popen_spawn_win32.py in __init__(self, process_obj)
64 try:
65 reduction.dump(prep_data, to_child)
---> 66 reduction.dump(process_obj, to_child)
67 finally:
68 context.set_spawning_popen(None)
C:\ProgramData\Miniconda3\lib\multiprocessing\reduction.py in dump(obj, file, protocol)
57 def dump(obj, file, protocol=None):
58 '''Replacement for pickle.dump() using ForkingPickler.'''
---> 59 ForkingPickler(file, protocol).dump(obj)
60
61 #
PicklingError: Can't pickle <function <lambda> at 0x000001D69AAC4E18>: attribute lookup <lambda> on __main__ failed
I have a suspicion that my use of a custom sampler class causes the issue. I use the ImageFolder pytorch class on my dataset. Then I use a custom sampler on the data. These are my samplers which randomly select batches to use for one-shot learning:
###############################################
#Build Custom Sampler Classes
###############################################
class SampleSampler(Sampler):
    '''Samples 'num_inst' examples each from 'num_cl' groups.

    For one-shot learning, num_inst is 1 per sampled class. Assumes the
    underlying dataset is grouped 'total_inst' examples per class, so the
    flat dataset index of example e of class c is total_inst*c + e.

    Args:
        num_cl: number of classes sampled for this episode.
        total_cl: total number of classes in the dataset.
        num_inst: examples drawn per sampled class (1 for one-shot).
        total_inst: examples available per class in the dataset.
        shuffle: if True, __iter__ yields the indices in shuffled order.
    '''
    def __init__(self, num_cl=20, total_cl=963, num_inst=1, total_inst=20, shuffle=True):
        self.num_cl = num_cl
        self.total_cl = total_cl
        self.num_inst = num_inst
        self.total_inst = total_inst
        # Pick which classes appear in this episode (no repeats).
        self.cl_list = list(np.random.choice(total_cl, num_cl, replace=False))
        # One random example index per sampled class.
        # BUGFIX: was size=num_inst*20 — the hard-coded 20 only worked
        # because the default num_cl happens to be 20; num_cl > 20 would
        # raise IndexError in the loop below.
        self.ex_list = list(np.random.randint(total_inst, size=num_inst*num_cl))
        self.shuffle = shuffle
        # Flat dataset index of each pick: total_inst*class + example
        # (was hard-coded 20 instead of the total_inst parameter).
        batch = [total_inst * cl + self.ex_list[i]
                 for i, cl in enumerate(self.cl_list)]
        mix = batch[:]
        if self.shuffle:
            np.random.shuffle(mix)
        self.batch = batch  # indices in class order
        self.mix = mix      # same indices, shuffled
    def __iter__(self):
        # Return a single list of indices, assuming that items are
        # grouped total_inst per class.
        if self.shuffle:
            return iter(self.mix)
        else:
            return iter(self.batch)
    # The following functions help you retrieve instances.
    # Index of original dataset will be total_inst*class + example.
    def get_classes(self):
        return self.cl_list
    def get_examples(self):
        return self.ex_list
    def get_batch_idc(self):
        return self.batch
    def __len__(self):
        return len(self.batch)
class QuerySampler(Sampler):
    '''Samples queries based on class list and example list.

    For each class in cl_list, draws 'num_inst' query examples while
    excluding the example already used as the one-shot sample (ex_list).
    Assumes the dataset is grouped 'total_inst' examples per class, so
    the flat index of example e of class c is total_inst*c + e.

    Args:
        cl_list: classes sampled for the episode (from SampleSampler).
        ex_list: per-class example index already used as the sample.
        num_inst: queries per class; must be <= total_inst - 1.
        shuffle: if True, the combined query batch is shuffled.
        total_inst: examples available per class (was hard-coded to 20;
            added as a trailing default for backward compatibility).
    '''
    def __init__(self, cl_list, ex_list, num_inst=19, shuffle=False, total_inst=20):
        self.cl_list = cl_list
        self.ex_list = ex_list
        self.num_inst = num_inst
        self.shuffle = shuffle
        self.total_inst = total_inst
        batch = []
        for i, cl in enumerate(self.cl_list):
            # All examples of this class except the one used as the sample.
            remaining_ex = list(range(total_inst))
            remaining_ex.remove(self.ex_list[i])
            queries = random.sample(remaining_ex, self.num_inst)
            for query in queries:
                batch.append(total_inst * cl + query)
        if self.shuffle:
            np.random.shuffle(batch)
        self.batch = batch
    def __iter__(self):
        # Return a single list of indices, assuming that items are
        # grouped total_inst per class.
        return iter(self.batch)
    # The following functions help you retrieve instances.
    # Index of original dataset will be total_inst*class + example.
    def get_classes(self):
        return self.cl_list
    def get_examples(self):
        return self.ex_list
    def get_batch_idc(self):
        return self.batch
    def __len__(self):
        return len(self.batch)
What exactly is the PicklingError failing on, and how would I fix it? Is my custom sampler causing the problem, or is it something else entirely?