I have an IterableDataset where inside __next__()
I raise StopIteration
when iteration is supposed to finish (__iter__()
returns self
).
However for some reason the iteration keeps going. Why could this be happening?
Here’s the the complete implementation:
class AudioData(IterableDataset):
def __init__(self, entries, data_fname):
self.data_fname = data_fname
self.entries = entries
self.data_fh = h5py.File(data_fname, 'r')
self.feat_d = self.data_fh['feats']
self.targ_d = self.data_fh['targs']
self.entry_idx = 0
self.feat_buffer_size = 1024 * 20 * FS
self.targ_buffer_size = 1024 * 20 * 100
self.entries_iter = iter([])
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.entries_group = self.entries
else:
num_workers = worker_info.num_workers
worker_id = worker_info.id
num_per_worker = len(self.lst_entries) // num_workers
self.entries_group = self.entries[worker_id*num_per_worker: (worker_id+1)*num_per_worker]
self.thread = None
def __len__(self):
return len(self.entries_group)
def _init_buffers(self, entry_idx):
f_start_idx = self.entries_group[entry_idx].feat_idx
t_start_idx = self.entries_group[entry_idx].targ_idx
f_end_idx = f_start_idx
t_end_idx = t_start_idx
self.next_entries = []
buffer_size = 0
while True:
if buffer_size >= self.feat_buffer_size or entry_idx >= len(self.entries_group):
break
entry = self.entries_group[entry_idx]
f_end_idx = entry.feat_idx + entry.size
t_end_idx = entry.targ_idx + entry.steps
self.next_entries.append(entry)
entry_idx += 1
buffer_size += entry.size
self.next_feat_buffer = self.feat_d[f_start_idx: f_end_idx]
self.next_targ_buffer = self.targ_d[t_start_idx: t_end_idx]
self.next_indices = (entry_idx, f_start_idx, t_start_idx,)
def init_buffers(self):
if self.entry_idx >= len(self.entries_group):
return False
if self.thread is None:
self.thread = Thread(target=self._init_buffers, args=(self.entry_idx,), daemon=True)
self.thread.start()
self.thread.join()
# No need for copy operation
self.curr_entries = self.next_entries
self.feat_buffer = self.next_feat_buffer
self.targ_buffer = self.next_targ_buffer
self.entry_idx, self.f_start_idx, self.t_start_idx = self.next_indices
random.shuffle(self.curr_entries)
self.entries_iter = iter(self.curr_entries)
if self.entry_idx < len(self.entries_group):
self.thread = Thread(target=self._init_buffers, args=(self.entry_idx,), daemon=True)
self.thread.start()
return True
def __next__(self):
entry = next(self.entries_iter, None)
if entry is None:
success = self.init_buffers()
if not success:
self.reset()
raise StopIteration
entry = next(self.entries_iter, None)
f = self.feat_buffer[entry.feat_idx - self.f_start_idx: entry.feat_idx - self.f_start_idx + entry.size]
t = self.targ_buffer[entry.targ_idx - self.t_start_idx: entry.targ_idx - self.t_start_idx + entry.steps]
return f, t
def reset(self):
self.entry_idx = 0