How to create a custom Dataset with variable-length input

Hi, I am a beginner with PyTorch but have experience using TensorFlow.

Currently, I want to write a custom Dataset that streams some .npy files from the HDD. At the same time, I also want to drop elements from the path list that do not meet certain conditions.
I tried to do this inside __getitem__(), but indexing raises an error because the path list has variable length (I pop() some bad entries). Maybe this is because __len__ is only called once, when the number of batches is computed?

My question is: what should I do to re-compute the length and indices before iterating? Thank you!

Here is my code:

import numpy as np
from torch.utils.data import Dataset

class NPY_Dataset(Dataset):
    def __init__(self, abspath_filelist: list[str]) -> None:
        self.abspath_list = abspath_filelist

    def __len__(self):
        return len(self.abspath_list)

    def __getitem__(self, idx) -> np.ndarray:
        data_path = self.abspath_list[idx]
        sample = np.load(data_path, allow_pickle=True)
        sample = sample.astype(np.float32)

        # data check: drop the bad path and retry with the same index
        if len(sample) < custom_size:
            self.abspath_list.pop(idx)
            return self.__getitem__(idx)

        return sample

One option is to use an Iterable-style Dataset (or IterDataPipe in TorchData), where your custom Dataset’s __iter__ method can skip over samples that do not meet your criteria (or call .filter() if you are using IterDataPipe). Since an iterable-style dataset has no fixed length or index, popping bad entries no longer causes an index mismatch.
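A minimal sketch of the iterable-style approach, assuming the same .npy layout and a hypothetical custom_size threshold from the original post:

```python
import numpy as np
from torch.utils.data import IterableDataset

class NPYIterableDataset(IterableDataset):
    def __init__(self, abspath_filelist, custom_size):
        self.abspath_list = abspath_filelist
        self.custom_size = custom_size  # hypothetical filter threshold

    def __iter__(self):
        for data_path in self.abspath_list:
            sample = np.load(data_path, allow_pickle=True).astype(np.float32)
            # Skip bad samples here; no __len__/index bookkeeping is
            # needed because iteration does not use a fixed length.
            if len(sample) < self.custom_size:
                continue
            yield sample
```

The dataset can then be passed to a DataLoader as usual, e.g. DataLoader(NPYIterableDataset(paths, custom_size), batch_size=4); just note that shuffling and sharding across workers must be handled inside __iter__ for iterable-style datasets.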

Thank you for your response! I will give the Iterable-style Dataset a try.