Hello I face a StopIteration error when testing my pytorch Dataset class for an image segmentation problem I am currently tackling. My Dataset Class is as follows:
class SegDataset(Dataset):
def __init__(self, df, fold=fold, train=True, augments=None):
self.df = df
self.fold = fold
self.train = train
self.augments = augments
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=SEED)
ids = df['id'].values
labels = df['organ'].values
ids = set(ids[list(skf.split(ids, labels))[self.fold][0 if self.train else 1]])
self.fnames = [fname for fname in os.listdir(PREPROC_TRAIN_PATH) if fname.split('_')[0] in ids]
def __len__(self):
return len(self.fnames)
def __get_item__(self, idx):
fname = self.fnames[idx]
img_path = os.path.join(PREPROC_TRAIN_PATH, fname)
mask_path = os.path.join(PREPROC_MASK_PATH, fname)
image = cv2.cvtColor(cv2.imread(img_path, cv2.COLOR_BGR2RGB))
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
image = (image.astype(np.float32)/255.0 - mean)/std
mask = image.astype(np.float32)/255.0
if self.augments is not None:
augmented = self.augments(image=image, mask=mask)
image, mask = augmented['image'], augmented['mask']
return img2tensor(image), img2tensor(mask)
In case you are wondering here is the img2tensor function if it helps:
def img2tensor(image, dtype:np.dtype = np.float32):
# for masks
if image.ndim == 2:
image = np.expand_dims(image,2)
image = np.transpose(image,(2,0,1)) # C , H , W
image = np.ascontiguousarray(image)
return torch.from_numpy(image.astype(dtype, copy=False))
The error generating code
ds = SegDataset(train_df)
dl = DataLoader(ds, batch_size=8)
imgs, msks = next(iter(dl))
The error message
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<timed exec> in <module>
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
528 if self._sampler_iter is None:
529 self._reset()
--> 530 data = self._next_data()
531 self._num_yielded += 1
532 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
567
568 def _next_data(self):
--> 569 index = self._next_index() # may raise StopIteration
570 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
571 if self._pin_memory:
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_index(self)
519
520 def _next_index(self):
--> 521 return next(self._sampler_iter) # may raise StopIteration
522
523 def _next_data(self):
StopIteration:
I do not have much experience in machine learning in general so any help will be appreciated.