Loading data from SQL with IterableDataset

I recently worked on a project where I needed to pull data directly from a SQL table via SQLAlchemy/ODBC. I found it quite challenging at the time, so I thought I’d share my approach here in case anybody finds themselves in a similar situation or would like to add to it.

The repo demonstrates how to fine-tune a Hugging Face transformer model for text classification (I chose QNLI for now). It can easily be configured for other subsets of the GLUE benchmark.

For the implementation of the dataset class, I chose to subclass IterableDataset (i.e. torch.utils.data.IterableDataset; see the torch.utils.data page of the PyTorch 1.12 documentation). Quoting the PyTorch documentation: “An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of datasets is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.” This contrasts with the standard map-style Dataset class, which implements a method __getitem__(index) for selecting a sample from the dataset by its index.
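To make the idea concrete, here is a minimal sketch of an iterable-style dataset that streams rows from a SQL table. It is not the code from the repo: I use the standard-library sqlite3 module as a stand-in for SQLAlchemy/ODBC so the example runs on its own, and the class name SQLIterableDataset, the qnli table, and its columns are all hypothetical.

```python
import os
import sqlite3
import tempfile

from torch.utils.data import DataLoader, IterableDataset


class SQLIterableDataset(IterableDataset):
    """Hypothetical dataset that streams rows from a SQL query in chunks."""

    def __init__(self, db_path, query, chunk_size=2):
        self.db_path = db_path
        self.query = query
        self.chunk_size = chunk_size

    def __iter__(self):
        # Open the connection lazily so every iterator gets its own.
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute(self.query)
            while True:
                # Fetch a limited number of rows at a time instead of the whole table.
                rows = cursor.fetchmany(self.chunk_size)
                if not rows:
                    break
                for question, sentence, label in rows:
                    yield {"question": question, "sentence": sentence, "label": label}
        finally:
            conn.close()


# Build a tiny throwaway table that stands in for the real SQL source.
db_path = os.path.join(tempfile.mkdtemp(), "demo.db")
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE qnli (question TEXT, sentence TEXT, label INTEGER)")
conn.executemany(
    "INSERT INTO qnli VALUES (?, ?, ?)",
    [(f"q{i}", f"s{i}", i % 2) for i in range(5)],
)
conn.commit()
conn.close()

dataset = SQLIterableDataset(
    db_path, "SELECT question, sentence, label FROM qnli ORDER BY rowid"
)
# The DataLoader only calls iter() on the dataset; it never indexes into it.
loader = DataLoader(dataset, batch_size=2)
batches = list(loader)
```

With the default collate function, the string columns of each batch come back as lists and the integer labels as a tensor. In a real fine-tuning setup, tokenization would typically happen either inside __iter__ or in a custom collate_fn.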

A downside of this approach is that the class cannot be combined with PyTorch Samplers: DataLoader rejects the sampler and shuffle=True options for iterable-style datasets, so any shuffling has to happen inside the dataset itself. Again quoting the PyTorch documentation: “For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample at each time).”
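Both points can be illustrated with a small sketch. The ChunkedDataset class below is hypothetical and uses an in-memory list in place of a SQL cursor. It yields an already-batched chunk on each step, and passing batch_size=None to the DataLoader turns off automatic batching so the chunks pass through as they are.

```python
from torch.utils.data import DataLoader, IterableDataset


class ChunkedDataset(IterableDataset):
    """Hypothetical dataset that yields one pre-built batch (chunk) per step."""

    def __init__(self, rows, chunk_size):
        self.rows = rows
        self.chunk_size = chunk_size

    def __iter__(self):
        # The dataset alone decides the order and the size of each batch.
        for start in range(0, len(self.rows), self.chunk_size):
            yield self.rows[start:start + self.chunk_size]


dataset = ChunkedDataset(list(range(7)), chunk_size=3)

# batch_size=None disables automatic batching: each chunk is one batch.
loader = DataLoader(dataset, batch_size=None)
chunk_sizes = [len(chunk) for chunk in loader]

# Asking the DataLoader to shuffle an iterable-style dataset is an error.
shuffle_error = None
try:
    DataLoader(dataset, shuffle=True)
except ValueError as err:
    shuffle_error = err
```

The last chunk is simply shorter than the others, which is the "dynamic batch size" the documentation refers to.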

The repo is pretty minimal at this point, but might still be useful for some.