Unstable behavior while iterating over Dataloader (collate_fn TypeError issue)

I have a custom Dataset object, which I am iterating over using a DataLoader.

Dataset code:

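import os
from torch.utils import data
# utils is my own helper module; utils.pkl_load loads a pickled file

class PatientDataset(data.Dataset):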
    def __init__(self, datadir, list_IDs): 
        self.datadir = os.path.join('data',datadir)
        self.list_IDs = list_IDs
        self.all_labels = {}
        avbl_dicts = {x.replace('_label_dict.pkl', '') : utils.pkl_load(os.path.join(self.datadir,x)) \
                        for x in os.listdir(self.datadir) if '_label_dict.pkl' in x} # {tgt_type : tgt_dict}
                        
        for ID in self.list_IDs:
            self.all_labels[ID] ={tgt_type:tgt_dict[ID] for tgt_type, tgt_dict in avbl_dicts.items()}

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        ID = self.list_IDs[index]
        X = utils.pkl_load(os.path.join(self.datadir, 'sequences', ID+'.npy')) # Numpy seq
        y = self.all_labels[ID] # Dict of available target values
        return X, y

Iteration code:

def fn_that_iterates(dl):
    for batch, (X, y) in enumerate(dl):
        pass  # do something

test_ds = pickle_load('test.dataset') # Loads the pickled dataset object
dl = DataLoader(test_ds, batch_size=len(test_ds))
#iterate
fn_that_iterates(dl)

Sometimes, when I run this as a script, I get a TypeError from collate_fn because it received a <PatientDataset> object instead of the (numpy array, dict) tuple that my Dataset's __getitem__ returns.

Running the same code in IPython never gives me the error.

Has anyone else faced this? Do we know why this happens?
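
Since the error says the collate step received a PatientDataset rather than a sample, one inexpensive check is to look at what the loaded object is and what indexing it returns before building the DataLoader. The sketch below is only an illustration: `describe_dataset` and `TinyDataset` are hypothetical names, and `TinyDataset` stands in for the real unpickled object. If, for example, the file held a container wrapping the dataset, `ds[0]` would itself be a dataset; that is just one possible cause, not a diagnosis.

```python
def describe_dataset(ds):
    """Return the type names of the dataset object and of its first sample."""
    sample = ds[0]
    if isinstance(sample, tuple):
        field_types = [type(part).__name__ for part in sample]
    else:
        field_types = None  # the sample is not an (X, y) style tuple
    return {
        "dataset_type": type(ds).__name__,
        "sample_type": type(sample).__name__,
        "sample_field_types": field_types,
    }


class TinyDataset:
    """Hypothetical stand-in for the unpickled dataset object."""

    def __len__(self):
        return 2

    def __getitem__(self, index):
        # Mimics the (sequence, dict of targets) pair from __getitem__.
        return [0.0, 1.0], {"label": index}


info = describe_dataset(TinyDataset())
print(info)
```

Running the same check on the real `test_ds` in both the script and IPython would show whether the two environments load the same kind of object.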

How did you save test.dataset, and what is your use case for storing it?
Based on the posted code, you appear to be loading the data lazily anyway, so storing the Dataset object might be unnecessary.
Do you see the same error if you recreate the PatientDataset and pass it to the DataLoader?
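
To make the "recreate and pass it to the DataLoader" check concrete, here is a minimal sketch using a toy dataset. `ToyPatientDataset` is a hypothetical in-memory stand-in for PatientDataset (it builds arrays instead of reading files), so it only shows what a correctly collated batch looks like with the default collate function.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyPatientDataset(Dataset):
    """Hypothetical in-memory stand-in for PatientDataset."""

    def __init__(self, list_IDs):
        self.list_IDs = list_IDs

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        X = np.full(3, float(index))  # stand-in for the loaded sequence
        y = {"label": index}          # stand-in for the dict of targets
        return X, y


# Build the dataset fresh instead of unpickling a stored object.
test_ds = ToyPatientDataset(["a", "b", "c", "d"])
dl = DataLoader(test_ds, batch_size=len(test_ds))

# The default collate function stacks the arrays into one tensor and
# collates the dicts key by key.
X, y = next(iter(dl))
print(type(X), tuple(X.shape), y["label"].tolist())
```

If the real dataset behaves like this when rebuilt but not when unpickled, the difference most likely lies in what was stored in the file.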

Thank you for your response @ptrblck. It is now running as expected; this happened without a single code change, so I am very confused.
I pickle.dump()ed it to a binary file. I’m storing it because I need to use the same datasets for a series of experiments.
I didn’t get the error when I ran the above in IPython, or when I recreated the Dataset and passed it to the DataLoader. It only happened when I ran it as a script from the command line.
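
If the goal is only to reuse the same split across experiments, one alternative to pickling the whole Dataset object is to store just the list of IDs and rebuild the dataset each time. This is a sketch under that assumption; `save_ids`, `load_ids`, the file name, and the example IDs are all hypothetical.

```python
import json
import os
import tempfile


def save_ids(list_ids, path):
    """Write the list of sample IDs to a JSON file."""
    with open(path, "w") as f:
        json.dump(list(list_ids), f)


def load_ids(path):
    """Read the list of sample IDs back from a JSON file."""
    with open(path) as f:
        return json.load(f)


# A temporary directory keeps the example self-contained.
split_path = os.path.join(tempfile.mkdtemp(), "test_ids.json")
save_ids(["patient_001", "patient_002"], split_path)

restored_ids = load_ids(split_path)
print(restored_ids)

# In each experiment the dataset would then be rebuilt from the IDs, e.g.
# test_ds = PatientDataset(datadir, restored_ids)
```

Because only plain strings are stored, the file does not depend on the PatientDataset class definition being importable in the same way when it is loaded.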
