SQL datapipe for sequence data

Hi everyone,
I’m working on a project where I have sequence data stored in a table (called pulsemap) in an SQL database, and info about each sequence stored in another table (called mc_truth). The sequences are stored so that each row is indexed by an event_no that identifies which event/sequence it belongs to. I feed each sequence into a transformer encoder and try to reconstruct the variables in mc_truth. If you want to know more about the data, it's very similar to this Kaggle competition.
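For context, this is roughly how I pull the list of event numbers out of the database up front (db_path is just the path to the SQLite file, and I'm assuming every event has a row in mc_truth):

import sqlite3

# Collect the identifier of every event/sequence once, up front.
with sqlite3.connect(db_path) as conn:
    event_nos = [row[0] for row in conn.execute("SELECT DISTINCT event_no FROM mc_truth")]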

I have a simple version where I create a Dataset that uses sqlite3 to query the database in __getitem__, plus a custom collate function that pads all the sequences in a batch to the same length.

The simple working version with the transformer is in this Google Colab here, which downloads example data from a GitHub repo.

I want to convert this Dataset and DataLoader to a DataPipe and DataLoader2 so that I can use torchdata.datapipes.iter.MaxTokenBucketizer to build the batches, since the simple implementation wastes a lot of compute on padding, and generally make the pipeline more efficient since I have a lot of data.
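To check my understanding of what MaxTokenBucketizer gives me, here is a toy example of how I think it would be used on a datapipe of (features, truth) tuples (the shapes and numbers here are made up):

import torch
from torchdata.datapipes.iter import IterableWrapper

# Stand-in for the real SQL datapipe: yields (features, truth) tuples where
# features has shape (seq_len, num_features).
event_dp = IterableWrapper(
    [(torch.randn(n, 4), torch.randn(1, 1)) for n in (5, 120, 37, 8)]
)

# Group events so each batch holds at most max_token_count pulses in total;
# len_fn tells the bucketizer how "long" a single sample is.
batched_dp = event_dp.max_token_bucketize(
    max_token_count=256,
    len_fn=lambda sample: sample[0].shape[0],
)

for batch in batched_dp:
    # batch is a list of (features, truth) tuples that still has to be
    # padded/collated into tensors.
    print([f.shape[0] for f, _ in batch])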

From what I could find, I need to implement something like this, but I'm unsure how:

@functional_datapipe("read_SQL")
class _ParseSQLData(IterDataPipe):

    def __init__(self, source_datapipe) -> None:
        self.source_datapipe = source_datapipe

    def __iter__(self):
        ...
        yield _features, _truth
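My rough attempt at filling this in and wiring everything together is below. It reuses the queries, the pad_collate function and the prometheus_path variable from my simple version further down, and the extra constructor arguments are just my guesses, so I don't know if this is the right way to structure it:

import sqlite3

import numpy as np
import torch
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper
from torchdata.dataloader2 import DataLoader2


@functional_datapipe("read_SQL")
class _ParseSQLData(IterDataPipe):
    # source_datapipe yields event_no values; this pipe yields one
    # (features, truth) pair per event, queried directly from SQLite.
    def __init__(self, source_datapipe, db_path, pulsemap, truth_table,
                 input_cols_str, target_cols_str) -> None:
        self.source_datapipe = source_datapipe
        self.db_path = db_path
        self.pulsemap = pulsemap
        self.truth_table = truth_table
        self.input_cols_str = input_cols_str
        self.target_cols_str = target_cols_str

    def __iter__(self):
        # One connection per pass over the pipe instead of one per event.
        with sqlite3.connect(self.db_path) as conn:
            for event_no in self.source_datapipe:
                features = torch.Tensor(conn.execute(
                    f"SELECT {self.input_cols_str} FROM {self.pulsemap} "
                    f"WHERE event_no == {event_no}").fetchall())
                truth = torch.Tensor(conn.execute(
                    f"SELECT {self.target_cols_str} FROM {self.truth_table} "
                    f"WHERE event_no == {event_no}").fetchall())
                yield features, truth


datapipe = (
    IterableWrapper(np.arange(10))
    .read_SQL(prometheus_path, "total", "mc_truth",
              "sensor_pos_x, sensor_pos_y, sensor_pos_z, t", "injection_energy")
    .max_token_bucketize(max_token_count=4096,
                         len_fn=lambda sample: sample[0].shape[0])
    .map(pad_collate)  # pad_collate is the collate function from the simple version
)
dataloader2 = DataLoader2(datapipe)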

Any help on how to make this DataPipe and DataLoader2, or comments on how to make it more efficient, would be much appreciated.

My simple dataset version without the transformer and training:

import sqlite3

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class SimpleSQLDataset(torch.utils.data.Dataset):
  def __init__(self, 
               db_path, 
               event_no_list,
               pulsemap,
               input_cols,
               target_cols,
               truth_table = "mc_truth"
               ):
    self.db_path = db_path
    self.event_no_list = event_no_list
    self.pulsemap = pulsemap
    self.input_cols = input_cols
    self.target_cols = target_cols
    self.truth_table = truth_table

    if isinstance(input_cols, list):
      self.input_cols_str = ", ".join(input_cols)
    else:
      self.input_cols_str = input_cols

    if isinstance(target_cols, list):
      self.target_cols_str = ", ".join(target_cols)
    else:
      self.target_cols_str = target_cols

    self.data_len = len(event_no_list)

  def __getitem__(self, index):
    event_no = self.event_no_list[index]
    # Fetch all pulses (rows) for this event and its truth values.
    with sqlite3.connect(self.db_path) as conn:
      features = torch.Tensor(conn.execute(f"SELECT {self.input_cols_str} FROM {self.pulsemap} WHERE event_no == {event_no}").fetchall())
      truth = torch.Tensor(conn.execute(f"SELECT {self.target_cols_str} FROM {self.truth_table} WHERE event_no == {event_no}").fetchall())

    return features, truth
  def __len__(self):
    return self.data_len

def pad_collate(batch):
  # batch is a list of (features, truth) tuples with variable-length features.
  (xx, y) = zip(*batch)
  x_lens = [len(x) for x in xx]
  # Pad every sequence in the batch to the length of the longest one.
  xx_pad = pad_sequence(xx, batch_first=True, padding_value=0)

  # Mask that is True at padded positions so the transformer can ignore them.
  pad_mask = torch.zeros_like(xx_pad[:, :, 0]).type(torch.bool)
  for i, length in enumerate(x_lens):
    pad_mask[i, length:] = True

  return xx_pad, torch.tensor(y), pad_mask

simpledataset = SimpleSQLDataset( 
               db_path = prometheus_path, 
               event_no_list = np.arange(10),
               pulsemap = "total",
               truth_table = "mc_truth",
               input_cols = ["sensor_pos_x","sensor_pos_y","sensor_pos_z","t"],
               target_cols = "injection_energy",
               )

dataloader = DataLoader(dataset=simpledataset, batch_size = 4, collate_fn = pad_collate)
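And a quick sanity check of what the simple version produces:

for xx_pad, y, pad_mask in dataloader:
    print(xx_pad.shape, y.shape, pad_mask.shape)
    break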