I have several time series that I want to classify with RNNs. I would like the `__getitem__` method to accept either a dict or a tuple that specifies both the sequence index and the time (frame) index. My current implementation is below.
class MLDataWrangler(zrfr.ZRFReader, torch.utils.data.Dataset):
    """Map-style dataset that serves (signal frames, finger labels) pairs.

    Each item corresponds to one entry of ``self._all_flexions`` (a path that
    is loaded through the inherited ``read``), optionally restricted to a
    subset of its frames.
    """

    ...
    ...

    def __len__(self) -> int:
        """Return the number of flexion sequences (time series) in the dataset."""
        # ``len()`` only accepts a single non-negative integer; returning a
        # tuple such as ``(n_sequences, sequence_length)`` raises TypeError,
        # which breaks DataLoader's default samplers. The per-sequence frame
        # count (``self._sequence_length``) therefore cannot be reported here.
        return len(self._all_flexions)

    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
        """Load the requested frames of one sequence together with its labels.

        Args:
            idx: One of
                * a dict with keys ``"flexion"`` (sequence index) and
                  ``"frames"`` (frame index/indices),
                * an indexable iterable ``(sequence, frames)``,
                * a plain integer sequence index, which selects every frame
                  of that sequence (this is what DataLoader's default
                  samplers produce).

        Returns:
            ``(img_pixels, labels)``: the selected frames cast to float32 and
            the per-finger labels (ordered as ``FINGER_NAMES``) as a float32
            array, each passed through ``self.transform`` /
            ``self.target_transform`` when those are set.

        Raises:
            TypeError: If ``idx`` is none of the supported index forms.
            ValueError: If ``self._signame`` is not one of the two handled
                entries of ``SIGNAL_NAMES``.
        """
        if isinstance(idx, dict):
            sequence = idx["flexion"]
            frames = idx["frames"]
        elif isinstance(idx, (int, np.integer)):
            # Bare integer: treat it as the sequence index and take all frames.
            sequence = idx
            frames = slice(None)
        elif hasattr(idx, "__iter__"):
            sequence = idx[0]
            frames = idx[1]
        else:
            raise TypeError(
                "index must be an int, a (sequence, frames) pair or a dict "
                f"with 'flexion' and 'frames' keys, got {type(idx).__name__}"
            )

        zrf_path = self._all_flexions[sequence]
        video_pixels = self.read(zrf_path).data

        # signal level (input)
        if self._signame == SIGNAL_NAMES[1]:
            signal = video_pixels.zsig
        elif self._signame == SIGNAL_NAMES[0]:
            signal = video_pixels.rfsig
        else:
            # Fail loudly instead of hitting an unbound local further down.
            raise ValueError(
                f"unsupported signal name {self._signame!r}; "
                f"expected {SIGNAL_NAMES[0]!r} or {SIGNAL_NAMES[1]!r}"
            )
        img_pixels = signal[frames].astype(np.float32)

        # finger ground truth: element [1] of the label record is indexed by
        # finger name; flatten it to an array ordered like FINGER_NAMES
        finger_labels = self._data_labels[sequence][1]
        labels = np.array(
            [finger_labels[fing] for fing in FINGER_NAMES]
        ).astype(np.float32)

        # transform as necessary
        if self.transform:
            img_pixels = self.transform(img_pixels)
        if self.target_transform:
            labels = self.target_transform(labels)

        # Re-cast because a transform may have changed the dtype.
        return img_pixels.astype(np.float32), labels
My question is: how do I use a DataLoader with this dataset? I assume I need a custom Sampler, but I’ve tried searching and can’t figure out what I need to implement.