IterableDataset never stops despite StopIteration being raised

I have an IterableDataset whose __next__() raises StopIteration when iteration is supposed to finish (__iter__() returns self).
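
(To be clear about the pattern I mean, here is a stripped-down toy version; the class name and count are just placeholders:)

from torch.utils.data import IterableDataset

class Toy(IterableDataset):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= self.n:
            raise StopIteration  # should end the epoch
        self.i += 1
        return self.i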

However, for some reason the iteration just keeps going. Why could this be happening?

Here’s the complete implementation:

import random
from threading import Thread

import h5py
import torch
from torch.utils.data import IterableDataset

# FS (the sample rate) is a constant defined elsewhere in my script.

class AudioData(IterableDataset):
    def __init__(self, entries, data_fname):
        self.data_fname = data_fname
        self.entries = entries

        self.data_fh = h5py.File(data_fname, 'r')
        self.feat_d = self.data_fh['feats']
        self.targ_d = self.data_fh['targs']
        self.entry_idx = 0
        self.feat_buffer_size = 1024 * 20 * FS
        self.targ_buffer_size = 1024 * 20 * 100
        self.entries_iter = iter([])
        # NOTE: this runs in __init__, i.e. in the main process, where
        # get_worker_info() still returns None (see the resolution below).
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.entries_group = self.entries
        else:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            num_per_worker = len(self.entries) // num_workers
            self.entries_group = self.entries[worker_id*num_per_worker: (worker_id+1)*num_per_worker]
        self.thread = None

    def __len__(self):
        return len(self.entries_group)

    def _init_buffers(self, entry_idx):
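        # Gather consecutive entries until the feature-buffer budget is hit,
        # then slice the matching regions out of the HDF5 datasets in one go.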
        f_start_idx = self.entries_group[entry_idx].feat_idx
        t_start_idx = self.entries_group[entry_idx].targ_idx
        f_end_idx = f_start_idx
        t_end_idx = t_start_idx
        self.next_entries = []
        buffer_size = 0
        while True:
            if buffer_size >= self.feat_buffer_size or entry_idx >= len(self.entries_group):
                break
            entry = self.entries_group[entry_idx]
            f_end_idx = entry.feat_idx + entry.size
            t_end_idx = entry.targ_idx + entry.steps
            self.next_entries.append(entry)
            entry_idx += 1
            buffer_size += entry.size
        self.next_feat_buffer = self.feat_d[f_start_idx: f_end_idx]
        self.next_targ_buffer = self.targ_d[t_start_idx: t_end_idx]
        self.next_indices = (entry_idx, f_start_idx, t_start_idx,)

    def init_buffers(self):
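        # Double-buffering: wait for the in-flight _init_buffers thread (or
        # start one on first use), swap its results in, then prefetch the
        # next chunk in the background.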
        if self.entry_idx >= len(self.entries_group):
            return False
        if self.thread is None:
            self.thread = Thread(target=self._init_buffers, args=(self.entry_idx,), daemon=True)
            self.thread.start()
        self.thread.join()
        # No need for copy operation
        self.curr_entries = self.next_entries
        self.feat_buffer = self.next_feat_buffer
        self.targ_buffer = self.next_targ_buffer
        self.entry_idx, self.f_start_idx, self.t_start_idx = self.next_indices
        random.shuffle(self.curr_entries)
        self.entries_iter = iter(self.curr_entries)
        if self.entry_idx < len(self.entries_group):
            self.thread = Thread(target=self._init_buffers, args=(self.entry_idx,), daemon=True)
            self.thread.start()
        return True

    def __next__(self):
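        # Take the next entry from the current shuffled chunk; when the chunk
        # is exhausted, swap in the prefetched one or signal the end.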
        entry = next(self.entries_iter, None)
        if entry is None:
            success = self.init_buffers()
            if not success:
                self.reset()
                raise StopIteration
            entry = next(self.entries_iter, None)
        f_off = entry.feat_idx - self.f_start_idx
        t_off = entry.targ_idx - self.t_start_idx
        f = self.feat_buffer[f_off: f_off + entry.size]
        t = self.targ_buffer[t_off: t_off + entry.steps]
        return f, t
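
    def __iter__(self):
        # Iteration state lives on the dataset object itself.
        return self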

    def reset(self):
        self.entry_idx = 0
        # Also drop the stale prefetch thread so the next epoch starts from
        # scratch instead of replaying the last chunk's buffers.
        self.thread = None

Okay, I had to set drop_last=True on the DataLoader.
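
For anyone else hitting this, the loader call then looks roughly like this (the batch size and worker count here are made-up placeholders):

from torch.utils.data import DataLoader

loader = DataLoader(AudioData(entries, data_fname),
                    batch_size=32,      # placeholder
                    num_workers=4,      # placeholder
                    drop_last=True)     # discard the final incomplete batch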

Also, I was making the mistake of assuming __init__ would be called again when the dataset is recreated for each DataLoader worker process, but that is not the case: each worker just gets a copy of the already-constructed dataset object, so get_worker_info() only returns real worker info when called later, e.g. inside __iter__ or in a worker_init_fn.
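
In case it helps someone, this is roughly the fix that follows from that: move the per-worker split out of __init__ and into __iter__, which does run inside the worker processes. The even slicing below is just one way to shard; the rest of the class stays as above.

from torch.utils.data import IterableDataset, get_worker_info

class AudioData(IterableDataset):
    # __init__, __next__, etc. unchanged, minus the get_worker_info()
    # block, which moves here:

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: this copy sees everything.
            self.entries_group = self.entries
        else:
            # Each worker takes its own contiguous slice of the entries.
            per_worker = len(self.entries) // info.num_workers
            start = info.id * per_worker
            self.entries_group = self.entries[start: start + per_worker]
        self.reset()
        return self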