Hey guy,code like this
class RSNA_volume(IterableDataset):
def __init__(self, stage2_train=None, instance_dir=None):
super(RSNA_volume, self).__init__()
# self.instances = glob.glob(str(Path(stage2_train) / "**" / "*.nii.gz"))
self.instances = stage2_train
self.instance_label = pd.read_csv(filepath_or_buffer=instance_dir, usecols=['InstanceID', 'Label'])
self.number_of_instances = len(self.instances)
LOGGER.info(f"Instances number in RSNA: {self.number_of_instances}")
LOGGER.info(f"Instances label shape: {self.instance_label.shape}")
def __iter__(self):
self.count = 0
return self
def __next__(self):
if self.count == self.number_of_instances:
raise StopIteration
instance = self.instances[self.count]
data = tio.ScalarImage(instance)
instance_name = Path(instance).stem.split('.')[0]
label = self.instance_label[self.instance_label['InstanceID'] == instance_name]['Label'].tolist()
label = eval(label[0])
# print(label)
self.count += 1
return data, torch.tensor(data=label)
def __len__(self):
return self.number_of_instances
and this is the torch turtorial example
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = torch.utils.data.get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
My question is how to change my code to be like the example.
For the reason I change it,cause I think my next is redundant,i check the code of dataloader,and find that next (iterator) has been implemented,so i think maye the code in next should be in iter,but i don’t know how to change it.I don;t know how to use yield.
Thanks.