I have a custom Dataset object that I iterate over using a DataLoader.
Dataset code:
class PatientDataset(data.Dataset):
    """Map-style dataset yielding ``(sequence, labels)`` pairs per patient ID.

    Label dictionaries are discovered at construction time: every file in
    the data directory whose name ends with ``_label_dict.pkl`` is loaded
    with ``utils.pkl_load`` and keyed by the filename prefix (the target type).
    """

    # Filename suffix that marks a per-target label dictionary in ``datadir``.
    _LABEL_SUFFIX = '_label_dict.pkl'

    def __init__(self, datadir, list_IDs):
        """Load all label dictionaries and index them by patient ID.

        Args:
            datadir: Directory name, joined under ``'data'``. If it is an
                absolute path, ``os.path.join`` discards the ``'data'`` prefix.
            list_IDs: Sequence of patient IDs; defines the dataset order.

        Raises:
            KeyError: If any loaded label dictionary lacks an entry for an ID.
        """
        self.datadir = os.path.join('data', datadir)
        self.list_IDs = list_IDs
        suffix = self._LABEL_SUFFIX
        # {tgt_type: tgt_dict}. Match on the suffix only (not a substring), so
        # files such as 'x_label_dict.pkl.bak' are not mistaken for label dicts.
        avbl_dicts = {
            name[:-len(suffix)]: utils.pkl_load(os.path.join(self.datadir, name))
            for name in os.listdir(self.datadir)
            if name.endswith(suffix)
        }
        # {ID: {tgt_type: target_value}} for every requested ID.
        self.all_labels = {
            ID: {tgt_type: tgt_dict[ID] for tgt_type, tgt_dict in avbl_dicts.items()}
            for ID in self.list_IDs
        }

    def __len__(self):
        """Return the number of patient IDs in the dataset."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``(X, y)`` for the patient at position ``index``.

        ``X`` is whatever ``utils.pkl_load`` returns for
        ``<datadir>/sequences/<ID>.npy``; ``y`` is that patient's dict of
        available target values.
        """
        ID = self.list_IDs[index]
        # NOTE(review): the file has a .npy extension but is read with
        # utils.pkl_load; if these files are written with np.save, np.load is
        # the matching reader — confirm how the sequence files are produced.
        X = utils.pkl_load(os.path.join(self.datadir, 'sequences', ID + '.npy'))
        y = self.all_labels[ID]  # Dict of available target values
        return X, y
Iteration code:
def fn_that_iterates(dl):
    """Iterate over every batch produced by ``dl``.

    Each batch is unpacked as ``(X, y)``, so every item yielded by ``dl``
    must be a 2-item sequence.

    Args:
        dl: Any iterable of 2-item batches, e.g. a ``DataLoader``.

    Raises:
        ValueError: If a batch does not unpack into exactly two items.
    """
    for batch, (X, y) in enumerate(dl):
        # TODO: process batch number `batch` with inputs `X` and targets `y`.
        pass
# Loads the pickled dataset object. Pickle records PatientDataset by its module
# path, so that class must be importable under the same name when loading.
test_ds = pickle_load('test.dataset')
# Collate the whole dataset into a single batch. DataLoader rejects
# batch_size=0, so clamp to 1 in case the dataset is empty.
dl = DataLoader(test_ds, batch_size=max(len(test_ds), 1))
# Iterate
fn_that_iterates(dl)
Sometimes, when I run this as a script, I get a `TypeError` from `collate_fn` because it received a `PatientDataset` object instead of the `(numpy array, dict)` tuple that my Dataset's `__getitem__` returns.
Running the same code in IPython never gives me the error.
Has anyone else run into this? Does anyone know why it happens?