IterableDataset: how to implement __iter__

Hey guys, my code looks like this:

class RSNA_volume(IterableDataset):
    """Iterable-style dataset yielding ``(image, label)`` pairs for RSNA volumes.

    Each item pairs a ``tio.ScalarImage`` built from one entry of
    ``stage2_train`` with a tensor parsed from the ``Label`` column of the
    CSV row whose ``InstanceID`` matches the file name (text before the
    first ``.``).

    With ``DataLoader(num_workers > 0)`` every worker process receives its own
    copy of this object, so ``__iter__`` restricts each worker to a disjoint
    slice of the instances; without that, every sample would be produced once
    per worker.
    """

    def __init__(self, stage2_train=None, instance_dir=None):
        """Store the instance paths and load the label table.

        Args:
            stage2_train: Indexable collection of image paths; it must
                support ``len()`` and integer indexing.
            instance_dir: Path or buffer passed to ``pd.read_csv``; only the
                ``InstanceID`` and ``Label`` columns are read.
        """
        super().__init__()
        # self.instances = glob.glob(str(Path(stage2_train) / "**" / "*.nii.gz"))
        self.instances = stage2_train
        self.instance_label = pd.read_csv(filepath_or_buffer=instance_dir, usecols=['InstanceID', 'Label'])
        self.number_of_instances = len(self.instances)
        LOGGER.info(f"Instances number in RSNA: {self.number_of_instances}")
        LOGGER.info(f"Instances label shape: {self.instance_label.shape}")

    def __iter__(self):
        """Reset the cursor to this process's slice and return ``self``.

        In single-process loading the slice is the whole dataset. Inside a
        DataLoader worker it is the worker's contiguous share, computed the
        same way as in the PyTorch ``IterableDataset`` documentation example.
        """
        total = self.number_of_instances
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: iterate over everything.
            first, last = 0, total
        else:
            # Ceiling division so the per-worker shares cover every instance.
            per_worker = -(-total // worker_info.num_workers)
            first = min(worker_info.id * per_worker, total)
            last = min(first + per_worker, total)
        self.count = first
        self._end = last
        return self

    def __next__(self):
        """Return the next ``(tio.ScalarImage, torch.Tensor)`` pair.

        Raises:
            StopIteration: When this process's slice is exhausted.
            IndexError: When the CSV has no row for the current instance.
        """
        # Function-scope import keeps this snippet self-contained.
        import ast

        if self.count >= self._end:
            raise StopIteration
        instance = self.instances[self.count]
        data = tio.ScalarImage(instance)

        # "name.nii.gz" -> stem "name.nii" -> "name"
        instance_name = Path(instance).stem.split('.')[0]
        matches = self.instance_label.loc[
            self.instance_label['InstanceID'] == instance_name, 'Label'
        ].tolist()
        if not matches:
            raise IndexError(f"No label found for instance {instance_name!r}")
        # NOTE: this used to be eval(); the label text comes from a CSV file,
        # so it is parsed strictly as a Python literal and never executed.
        label = ast.literal_eval(matches[0])
        self.count += 1

        return data, torch.tensor(data=label)

    def __len__(self):
        """Return the total number of instances (summed over all workers)."""
        return self.number_of_instances

And this is the example from the torch tutorial:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))

My question is how to change my code to be like the example.
The reason I want to change it is that I think my __next__ is redundant: I checked the DataLoader code and found that next(iterator) is already implemented there, so I think maybe the code in __next__ should be in __iter__, but I don't know how to change it. I don't know how to use yield. :smiling_face_with_tear: :smiling_face_with_tear:
Thanks.

Hi,

May I ask why you need an IterableDataset? I think this should work fine with a normal Dataset object, since you are just reading a CSV file and then reading an image whenever the next item is requested.

The tutorial in this link does the same thing, reading a file and then giving an image that corresponds to some path specified in the csv file.

The DataLoader will then take care to iterate through the dataset.

This is the code for the Dataset from the tutorial

class FaceLandmarksDataset(Dataset):
    """Map-style dataset pairing each image with the landmarks listed in a CSV.

    Column 0 of the CSV holds the image file name relative to ``root_dir``;
    the remaining columns are read as landmark values and reshaped to rows
    of two.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Load the annotation table and remember where the images live.

        Args:
            csv_file: Path of the CSV file passed to ``pd.read_csv``.
            root_dir: Directory that the image names are joined onto.
            transform: Optional callable applied to each sample dict.
        """
        self.transform = transform
        self.root_dir = root_dir
        self.landmarks_frame = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows in the annotation table."""
        frame = self.landmarks_frame
        return len(frame)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'landmarks': ...}`` for row ``idx``.

        The result is passed through ``self.transform`` when one is set.
        """
        # Tensor indices are converted to plain Python values first.
        idx = idx.tolist() if torch.is_tensor(idx) else idx

        file_name = self.landmarks_frame.iloc[idx, 0]
        image = io.imread(os.path.join(self.root_dir, file_name))

        raw_points = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([raw_points]).astype('float').reshape(-1, 2)

        sample = {'image': image, 'landmarks': landmarks}
        if not self.transform:
            return sample
        return self.transform(sample)

This would be then for the DataLoader

# Wrap the dataset in a DataLoader: shuffled batches of 4, num_workers=0.
dataloader = DataLoader(
    transformed_dataset,
    shuffle=True,
    batch_size=4,
    num_workers=0,
)

for i_batch, sample_batched in enumerate(dataloader):
    # Report the batch index and the sizes of both entries of the batch.
    print(
        i_batch,
        sample_batched['image'].size(),
        sample_batched['landmarks'].size(),
    )

    # Only the 4th batch (index 3) is plotted; earlier ones are just printed.
    if i_batch != 3:
        continue

    plt.figure()
    show_landmarks_batch(sample_batched)
    plt.axis('off')
    plt.ioff()
    plt.show()
    break

Thanks for your reply.
I know how to write a custom map-style dataset. I'd like to implement it in iterable style. :joy::joy:
But I don't know how to implement it. :joy::joy:

Oh ok,

then here is a blog with multiple examples going step by step into how to build one. This might be helpful.