How can I use a multi-indexed Dataset?

I have several time series that I want to classify with RNNs. I would like the __getitem__ method to accept either a dict or a tuple that specifies both the time (frame) index and the sequence index. My implementation is below:

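from typing import Tuple

import numpy as np
import torch

# zrfr, SIGNAL_NAMES and FINGER_NAMES are defined elsewhere in my project

class MLDataWrangler(zrfr.ZRFReader, torch.utils.data.Dataset):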
  ...
  ...

  def __len__(self) -> int:
    # __len__ must return a single int (here: the number of time series);
    # can I somehow also expose self._sequence_length?
    return len(self._all_flexions)

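  def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
    # accept either {"flexion": ..., "frames": ...} or a (sequence, frames) pair
    if isinstance(idx, dict):
      frames = idx["frames"]
      sequence = idx["flexion"]
    elif hasattr(idx, "__iter__"):
      frames = idx[1]
      sequence = idx[0]
    else:
      raise TypeError(f"unsupported index type: {type(idx)!r}")
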
    zrf_path = self._all_flexions[sequence]
    video_pixels = self.read(zrf_path).data


    # signal level (input)
    if self._signame == SIGNAL_NAMES[1]:
      img_pixels = video_pixels.zsig[frames].astype(np.float32)  # new since stacked
    elif self._signame == SIGNAL_NAMES[0]:
      img_pixels = video_pixels.rfsig[frames].astype(np.float32)  # new since stacked
    else:
      raise ValueError(f"unknown signal name: {self._signame!r}")


    # get finger ground truth
    labels = self._data_labels[sequence][1] # since stacked
    labels = np.array([labels[fing] for fing in FINGER_NAMES], dtype=np.float32)  # convert dict to np array

    # transform as necessary
    if self.transform:
      img_pixels = self.transform(img_pixels)
    if self.target_transform:
      labels = self.target_transform(labels)
    return img_pixels.astype(np.float32), labels
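To experiment without the ZRF files, here is a minimal self-contained stand-in that mirrors the indexing logic above. ToyMultiIndexDataset and the random arrays are hypothetical replacements for MLDataWrangler, self.read(...) and the finger labels; only the dict/tuple index handling is taken from the original.

```python
from typing import Tuple

import numpy as np
from torch.utils.data import Dataset


class ToyMultiIndexDataset(Dataset):
    """Hypothetical stand-in for MLDataWrangler backed by in-memory arrays."""

    def __init__(self, series, labels):
        self._series = series  # list of (num_frames, num_features) arrays
        self._labels = labels  # (num_sequences, num_classes) array

    def __len__(self) -> int:
        return len(self._series)  # number of time series

    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
        # same two index forms as above: a dict or a (sequence, frames) pair
        if isinstance(idx, dict):
            sequence, frames = idx["flexion"], idx["frames"]
        elif isinstance(idx, (tuple, list)):
            sequence, frames = idx[0], idx[1]
        else:
            raise TypeError(f"unsupported index type: {type(idx)!r}")
        x = self._series[sequence][frames].astype(np.float32)
        y = self._labels[sequence].astype(np.float32)
        return x, y


rng = np.random.default_rng(0)
ds = ToyMultiIndexDataset(
    [rng.standard_normal((50, 3)) for _ in range(4)],  # 4 series, 50 frames, 3 features
    rng.integers(0, 2, size=(4, 5)),  # 5 binary labels per series
)
x_t, y_t = ds[(2, slice(10, 20))]
x_d, y_d = ds[{"flexion": 2, "frames": slice(10, 20)}]
print(x_t.shape, y_t.shape)  # (10, 3) (5,)
```

Both index forms select frames 10 to 19 of series 2, so they return identical arrays.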

My question is: how do I use a DataLoader with this dataset? I assume I need a custom Sampler, but I’ve searched and can’t figure out what I need to implement.

I don’t know if a custom sampler yielding multiple indices would work, as I would expect to run into issues in the fetch method, e.g. here.
However, you might be able to use a BatchSampler as described in this post, which would allow you to pass e.g. a list of indices to the Dataset.__getitem__ method. You might need to “encode” this list (e.g. the first part holds index0, the second part index1, etc.), but it could work.
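A minimal sketch of that encoding idea, assuming fixed-length frame windows. The hypothetical EncodedBatchSampler yields one flat list per batch: the first half holds sequence indices (index0) and the second half holds window start frames (index1). Passing it as sampler= together with batch_size=None disables automatic batching, so the DataLoader hands each list unchanged to __getitem__ and only converts the returned NumPy arrays to tensors. EncodedWindowDataset, the window length and the random data are illustrative stand-ins, not part of the original code.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset, Sampler


class EncodedWindowDataset(Dataset):
    """Hypothetical dataset where one __getitem__ call builds a whole batch."""

    def __init__(self, series, labels, window):
        self.series = series  # list of (num_frames, num_features) arrays
        self.labels = labels  # (num_sequences, num_classes) array
        self.window = window  # assumed fixed number of frames per sample

    def __len__(self):
        return len(self.series)

    def __getitem__(self, encoded):
        # decode [seq_0, ..., seq_{B-1}, start_0, ..., start_{B-1}]
        half = len(encoded) // 2
        seqs, starts = encoded[:half], encoded[half:]
        x = np.stack([self.series[s][t:t + self.window] for s, t in zip(seqs, starts)])
        y = self.labels[seqs]
        return x.astype(np.float32), y.astype(np.float32)


class EncodedBatchSampler(Sampler):
    """Hypothetical sampler yielding one encoded index list per batch."""

    def __init__(self, lengths, window, batch_size, num_batches, seed=0):
        self.lengths = lengths  # number of frames in each sequence
        self.window = window
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            seqs = [int(s) for s in self.rng.integers(0, len(self.lengths), size=self.batch_size)]
            # pick a start frame so the whole window fits inside each sequence
            starts = [int(self.rng.integers(0, self.lengths[s] - self.window + 1)) for s in seqs]
            yield seqs + starts


rng = np.random.default_rng(0)
series = [rng.standard_normal((n, 3)) for n in (50, 80, 65)]
labels = rng.integers(0, 2, size=(3, 5))
ds = EncodedWindowDataset(series, labels, window=10)
sampler = EncodedBatchSampler([len(s) for s in series], window=10, batch_size=4, num_batches=5)

# batch_size=None: no automatic batching, each encoded list goes straight to __getitem__
loader = DataLoader(ds, sampler=sampler, batch_size=None)
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 10, 3]) torch.Size([4, 5])
```

The design choice here is that the sampler decides both indices, so the dataset never needs a multi-valued __len__; the DataLoader length simply comes from the sampler’s __len__.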