Filter out items from Dataset? "filter_pred" for any dataset

Hello,

For some datasets in torchtext, I see the parameter filter_pred, which filters out part of the dataset.

I’m using torchtext.datasets, which gives an IterableDataset that you can pass to the DataLoader. Is it possible to simply apply a filter like filter_pred to it?

This is my solution, which I don’t like. If you have something better, please share.

from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterableDataset, pred_filter=None):
        self.iterableDataset = iterableDataset
        self.pred_filter = pred_filter

    def __iter__(self):
        # Generator: lazily yield only the items that pass the filter.
        for x in self.iterableDataset:
            if self.pred_filter and not self.pred_filter(x):
                continue
            yield x

Does this work for you?

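from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):

    def __init__(self, iterableDataset, pred_filter=None):
        self.iterableDataset = iterableDataset
        self.pred_filter = pred_filter

    def __iter__(self):
        # Start a fresh pass over the wrapped dataset on every call;
        # next() only works on an iterator, not on a plain iterable.
        self._it = iter(self.iterableDataset)
        return self

    def __next__(self):
        # Skip items until one passes the filter. StopIteration raised
        # by next() propagates and ends the iteration.
        while True:
            item = next(self._it)
            if self.pred_filter is None or self.pred_filter(item):
                return item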
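
For anyone who wants to check the filtering behaviour without installing torch or downloading a dataset, here is a minimal plain-Python sketch of the same idea. The class name FilteredIterable and the sample data are made up for illustration; in real use the class would subclass torch.utils.data.IterableDataset as in the posts above.

```python
class FilteredIterable:
    """Lazily yields only the items for which pred_filter returns True.

    Hypothetical stand-in for the wrapper discussed above; it does not
    depend on torch so the logic can be tested on its own.
    """

    def __init__(self, iterable, pred_filter=None):
        self.iterable = iterable
        self.pred_filter = pred_filter

    def __iter__(self):
        # A generator-based __iter__ starts a new pass on every call.
        for item in self.iterable:
            if self.pred_filter is None or self.pred_filter(item):
                yield item


# Made-up (label, text) samples, loosely shaped like a text-classification dataset.
samples = [(1, "good movie"), (0, "bad"), (1, "great"), (0, "not my thing at all")]

# Keep only the samples whose text has at most two words.
short_only = FilteredIterable(samples, pred_filter=lambda s: len(s[1].split()) <= 2)

first_pass = list(short_only)
second_pass = list(short_only)  # a second pass works because samples is a list
```

Note that a second pass only yields items if the wrapped object can itself be iterated again. If it is a one-shot iterator, the second pass comes back empty.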