Hi everyone,
I’m working on a project where I have sequence data stored in a table (called pulsemap) in an SQL database, and info about each sequence stored in another table (called mc_truth). The sequences are stored so that each row is indexed by an event_no identifying which event/sequence it belongs to. I input the sequence into a transformer encoder and try to reconstruct the variables in mc_truth. If you want to know more about the data, it's very similar to this Kaggle competition.
I have a simple version where I create a dataset that uses sqlite3 to query the database in __getitem__ and uses a custom collate function to pad all the sequences to the same size.
the simple working version with the transformer is in this Google colab here. which downloads example data from a github repo.
I want to convert this dataset and dataloader to a datapipe and DataLoader2 so that I can use
torchdata.datapipes.iter.MaxTokenBucketizer to make the batches, since the simple implementation is very inefficient with padding, and generally to make it more efficient since I have a lot of data.
From what I could find, I need to implement something like this, but I'm unsure how:
@functional_datapipe("read_SQL")
class _ParseSQLData(IterDataPipe):
    """Sketch of an IterDataPipe that streams (features, truth) pairs
    from the SQL database; the body is still to be implemented.

    Registered as ``.read_SQL(...)`` on upstream datapipes via the
    ``functional_datapipe`` decorator.
    """

    def __init__(self, source_datapipe) -> None:
        # Upstream pipe — presumably yields event_no values (or similar
        # keys) to query against the database; TODO confirm what it emits.
        self.source_datapipe = source_datapipe

    def __iter__(self):
        # TODO: iterate self.source_datapipe, query pulsemap/mc_truth per
        # event, and yield tensors here.
        ...
        yield _features, _truth  # NOTE(review): _features/_truth are undefined placeholders
Any help on how to make this datapipe and dataloader, or comments on how to make it more efficient, would be much appreciated.
My simple dataset version without the transformer and training:
class SimpleSQLDataset(torch.utils.data.Dataset):
    """Map-style dataset reading variable-length pulse sequences from SQLite.

    Each item is one event, selected by ``event_no``:
      * ``features`` — (seq_len, n_input_cols) tensor of rows from
        ``pulsemap``;
      * ``truth``    — (n_rows, n_target_cols) tensor of rows from
        ``truth_table``.

    Args:
        db_path: Path to the SQLite database file.
        event_no_list: Sequence of event numbers; index i maps to
            event_no_list[i].
        pulsemap: Name of the table holding the per-pulse sequence data.
        input_cols: Column name(s) to read from ``pulsemap`` — a list of
            names or a ready-made comma-separated string.
        target_cols: Column name(s) to read from ``truth_table`` — same
            convention as ``input_cols``.
        truth_table: Name of the per-event truth table.
    """

    def __init__(self,
                 db_path,
                 event_no_list,
                 pulsemap,
                 input_cols,
                 target_cols,
                 truth_table="mc_truth",
                 ):
        self.db_path = db_path
        self.event_no_list = event_no_list
        self.pulsemap = pulsemap
        self.input_cols = input_cols
        self.target_cols = target_cols
        self.truth_table = truth_table
        # Accept either a list of column names or an already-joined string.
        if isinstance(input_cols, list):
            self.input_cols_str = ", ".join(input_cols)
        else:
            self.input_cols_str = input_cols
        if isinstance(target_cols, list):
            self.target_cols_str = ", ".join(target_cols)
        else:
            self.target_cols_str = target_cols
        self.data_len = len(event_no_list)

    def __getitem__(self, index):
        event_no = self.event_no_list[index]
        # NOTE: sqlite3's "with conn:" only wraps a transaction — it does
        # NOT close the connection, so the original leaked one connection
        # per item. Close explicitly instead.
        conn = sqlite3.connect(self.db_path)
        try:
            # Bind event_no as a query parameter rather than f-string
            # interpolation; identifiers (table/column names) cannot be
            # bound, so those remain interpolated from constructor args.
            # int(...) also accepts numpy integers (e.g. from np.arange).
            features = torch.Tensor(
                conn.execute(
                    f"SELECT {self.input_cols_str} FROM {self.pulsemap} "
                    "WHERE event_no = ?",
                    (int(event_no),),
                ).fetchall()
            )
            truth = torch.Tensor(
                conn.execute(
                    f"SELECT {self.target_cols_str} FROM {self.truth_table} "
                    "WHERE event_no = ?",
                    (int(event_no),),
                ).fetchall()
            )
        finally:
            conn.close()
        return features, truth

    def __len__(self):
        return self.data_len
def pad_collate(batch):
    """Collate (features, truth) pairs of variable sequence length.

    Args:
        batch: list of (features, truth) tensor pairs, features shaped
            (seq_len_i, n_features).

    Returns:
        xx_pad: (B, T_max, n_features) zero-padded feature batch.
        yy: (B, *truth_shape) stacked truth targets.
        pad_mask: (B, T_max) bool mask, True at padded positions — the
            convention of nn.Transformer's ``src_key_padding_mask``.
    """
    xx, yy = zip(*batch)
    x_lens = torch.tensor([len(x) for x in xx])
    xx_pad = pad_sequence(xx, batch_first=True, padding_value=0)
    # Vectorized mask: True where the position index is at/beyond the
    # sequence's real length (replaces the per-row Python loop).
    pad_mask = torch.arange(xx_pad.size(1)).unsqueeze(0) >= x_lens.unsqueeze(1)
    # torch.stack handles truth tensors of any size; the original
    # torch.tensor(y) raises for multi-element truth tensors
    # (multi-target case).
    return xx_pad, torch.stack(yy), pad_mask
# Example run: the first ten events of the "total" pulsemap, with the
# three sensor coordinates plus time as inputs and the injected energy
# as the regression target.
simpledataset = SimpleSQLDataset(
    db_path=prometheus_path,
    event_no_list=np.arange(10),
    pulsemap="total",
    truth_table="mc_truth",
    input_cols=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
    target_cols="injection_energy",
)

# Batch four events at a time, padding each batch to its longest sequence.
dataloader = DataLoader(
    dataset=simpledataset,
    batch_size=4,
    collate_fn=pad_collate,
)