I have implemented a custom PyTorch Dataset.
The dataset is initialized with a folder. Each file in the folder contains a long signal that I do some signal processing on to split it into M segments.
Right now, for each file, the __getitem__() method generates an NxM matrix, where N is a known number of features and M is the number of segments extracted from that file. I have no way of knowing M in advance, i.e. not until I analyze the entire signal, and it’s different per file.
The __len__() method currently returns the number of files in the folder the Dataset was initialized with.
What I actually want is to work with these segments as individual samples for my model, not the entire signal.
In other words, I want batch sizes of say 128 segments (so a batch would be of shape 128xN).
Ideally, I would give my dataset to some customized DataLoader and it would create these batches on the fly by loading one file (with M>128 segments) and taking random batches of 128 segments. When there aren’t enough segments left in the loaded file to fill a batch, a new file should be loaded (in random order) and so on.
I tried looking into the BatchSampler classes and also at the custom collate_fn that can be provided to a DataLoader, but I haven’t found a way to achieve this… Everything seems to expect the number of samples to be known in advance.
So, is there some trick I can use?
Or do I need to resort to simply saving each segment to a separate file and loading them using the standard 1file=1sample based approach?
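To make the behaviour I’m after concrete, here is a rough plain-Python sketch, not something I have working with a DataLoader. load_and_split is a hypothetical stand-in for my signal processing, a "file" is faked as a (name, M) pair, and carrying leftover segments over into the next file’s pool is just one assumption about how a partial batch could be handled.

```python
import random

def load_and_split(path):
    """Hypothetical stand-in for the real signal processing: returns M segments."""
    # A "file" is faked as a (name, M) pair; each segment is a list of N=3 features.
    name, m = path
    return [[f"{name}:{i}"] * 3 for i in range(m)]

def segment_batches(paths, batch_size, rng=None):
    """Yield batches of `batch_size` segments, processing one file at a time."""
    rng = rng or random.Random()
    paths = list(paths)
    rng.shuffle(paths)                      # visit the files in random order
    pool = []                               # segments loaded but not yet served
    for path in paths:
        pool.extend(load_and_split(path))   # the file is only processed here
        rng.shuffle(pool)                   # random order among loaded segments
        while len(pool) >= batch_size:
            batch, pool = pool[:batch_size], pool[batch_size:]
            yield batch
    if pool:                                # final, possibly smaller batch
        yield pool

files = [("a", 10), ("b", 7), ("c", 5)]
batches = list(segment_batches(files, batch_size=4, rng=random.Random(0)))
```

Nothing here needs the total number of segments up front, which is the property I can’t seem to get out of the Sampler-based machinery.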
I don’t know a valid approach without knowing the number of segments in each file.
Don’t you have any way to calculate the number of segments beforehand?
Another approach would be to set an arbitrarily high number as the length; if you encounter an index that is too high, you can choose between two options:
You can modify the index to be in your range (e.g. by using the % operator). This would result in a fixed epoch length, but probably iterating much more often over some parts of the dataset (should not matter if you shuffle anyway).
You can raise StopIteration yourself, which should be caught inside the DataLoader/automatically inside the loop. This is what I did in early releases of PyTorch to write my own ImageFolder-like datasets (if it is not caught, you need to catch it yourself).
Note that the second approach might cause loggers or other modules which are built upon a fixed number of batches to raise errors.
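A minimal sketch of the first option, assuming the real samples are already available as a list. WrappedSegments and nominal_length are names made up for this illustration, and the class is written without torch so it runs on its own; it only provides the __len__ and __getitem__ methods a map-style dataset needs.

```python
class WrappedSegments:
    """Map-style dataset whose nominal length exceeds the real sample count."""

    def __init__(self, segments, nominal_length):
        self.segments = segments              # the real samples
        self.nominal_length = nominal_length  # arbitrary, deliberately too high

    def __len__(self):
        return self.nominal_length

    def __getitem__(self, index):
        if not 0 <= index < self.nominal_length:
            raise IndexError(index)
        # Fold an out-of-range index back into the valid range.
        return self.segments[index % len(self.segments)]

ds = WrappedSegments(["s0", "s1", "s2"], nominal_length=10)
epoch = [ds[i] for i in range(len(ds))]
```

Here the epoch always has 10 items, and "s0" is visited more often than the other two samples, which is the uneven coverage mentioned above.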
Thanks. Interesting ideas.
Let’s say I can look at the folder, read some metadata about each file and calculate a rough upper bound K of the number of segments in all files combined.
The problem is that, given an index in the range [0,K-1], it’s not possible to know which file to look at to find that index. I could again do some estimations to just pick a file based on the index in a deterministic way, load it, process it and use the % operator to get a segment inside it even if it actually doesn’t have enough segments.
However, this would cause lots of re-loading of the files and re-running the signal processing algorithms that split them… Each file would be processed many times. It would cause a non-negligible overhead. Probably I could fix it with some caching, but even that might still be slow since I have a lot of data. The reason I was looking for some DataLoader trick here was to do all the loading and processing lazily, on demand…
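Roughly what I have in mind, as a toy sketch rather than my real code. FILES, EST_PER_FILE and load_and_split are all made up for the illustration, and the cache is simply functools.lru_cache holding the most recently split file.

```python
from functools import lru_cache

FILES = {"a.sig": 5, "b.sig": 3}   # hypothetical: file name -> true segment count
EST_PER_FILE = 6                   # hypothetical per-file upper bound, so K = 12
LOAD_COUNT = {"n": 0}              # counts how often the expensive step runs

@lru_cache(maxsize=1)              # keep only the most recently split file
def load_and_split(name):
    LOAD_COUNT["n"] += 1
    return [f"{name}[{i}]" for i in range(FILES[name])]

def get_segment(index):
    names = sorted(FILES)
    name = names[index // EST_PER_FILE]   # deterministic guess of the file
    segments = load_and_split(name)       # cached while we stay in this file
    # Wrap inside the file when it has fewer segments than estimated.
    return segments[(index % EST_PER_FILE) % len(segments)]

sequential = [get_segment(i) for i in range(2 * EST_PER_FILE)]
loads_when_sequential = LOAD_COUNT["n"]
```

With sequential indices each file is split once, but a shuffling sampler would jump between files and evict the cache almost every time, which is the overhead I’m worried about.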
How exactly is your dataset structured?
I have uploaded a small snippet here which is not directly related to your problem but should show how to use StopIteration inside an iterable (in this case the dataloaders are simply reset if running out of bounds).
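For readers without access to the linked snippet, the general idea can be sketched as below. This is not the uploaded code, just an illustration under my own naming (ResettingLoader is hypothetical): the wrapper catches StopIteration from the underlying iterable and starts it over.

```python
class ResettingLoader:
    """Wrap an iterable (e.g. a data loader) so it restarts when exhausted."""

    def __init__(self, iterable):
        self.iterable = iterable
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # Ran out of bounds: reset and continue from the beginning.
            self._iterator = iter(self.iterable)
            return next(self._iterator)

loader = ResettingLoader([1, 2, 3])
first_five = [next(loader) for _ in range(5)]
```

Since the wrapper never stops on its own for a non-empty iterable, the surrounding loop has to decide when an epoch ends.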
While this approach can be used to create a Dataset with an unknown size (which is good!), it’s still not “lazy enough” just like that, because when combined with Samplers and a DataLoader I need to make sure to only generate indices of segments within the file the Dataset currently has loaded and split.
It now seems possible to perhaps do this by coupling the Dataset and the Sampler, but I think for simplicity I’ll just pre-process all the files and split them in advance.
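Once the pre-processing pass has recorded how many segments each file produced, mapping a global index to a file becomes straightforward. A small sketch of that lookup, where build_index and locate are names I made up and the counts are toy values:

```python
from bisect import bisect_right
from itertools import accumulate

def build_index(counts):
    """counts: list of (file_name, n_segments) gathered in a one-time pass."""
    names = [name for name, _ in counts]
    ends = list(accumulate(n for _, n in counts))   # running total of segments
    return names, ends

def locate(names, ends, index):
    """Translate a global segment index into (file_name, index inside file)."""
    if not 0 <= index < ends[-1]:
        raise IndexError(index)
    file_pos = bisect_right(ends, index)   # first file whose end exceeds index
    start = ends[file_pos - 1] if file_pos else 0
    return names[file_pos], index - start

names, ends = build_index([("a", 3), ("b", 2), ("c", 4)])
total_segments = ends[-1]                  # this is what __len__ could return
```

With this, __len__ is exact and __getitem__ only needs the one file that locate points at.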
Thanks for your help
If you could provide more details about your dataset structure (or maybe some example files), we may be able to help you further.