I want to chain two datasets together with ChainDataset. One of them is a normal (map-style) Dataset and the other is an IterableDataset, but ChainDataset only accepts IterableDatasets.
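For context, chaining iterable datasets means exhausting the first one and then moving on to the next, much like `itertools.chain`. The sketch below is a pure-Python analogy: `chain_datasets` is a hypothetical helper and plain lists stand in for the two datasets. It illustrates the idea only and is not PyTorch's implementation.

```python
def chain_datasets(*datasets):
    """Hypothetical helper: yield every item of each dataset in turn."""
    for ds in datasets:
        # Exhaust the current dataset before starting the next one
        yield from ds


# Plain lists stand in for the two datasets (illustration only)
first = [1, 2, 3]
second = ["a", "b"]
print(list(chain_datasets(first, second)))  # [1, 2, 3, 'a', 'b']
```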
So I thought I'd just create a simple iterable version of my original map-style dataset, but it's not working. Sample code is below.
I've found a few related questions, such as this one and this Stack Overflow question, but none of them has an answer.
How do we do this? Here’s my attempt:
import glob
import random

import torch


class MapStyleDataset(torch.utils.data.Dataset):
    def __init__(self, path, **kwargs):
        super().__init__()
        self.filenames = glob.glob(path)

    def __getitem__(self, idx):
        # somehow_read_from is a placeholder for the real loading code
        data = somehow_read_from(self.filenames[idx])
        return data

    def __len__(self):
        return len(self.filenames)


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Forward the arguments unchanged to the map-style dataset
        self.this = MapStyleDataset(*args, **kwargs)
        self.len = len(self.this)

    def __iter__(self):
        # random.randint is inclusive at both ends, so the last valid index is len - 1
        return self.this.__getitem__(random.randint(0, self.len - 1))
Let's assume MapStyleDataset works fine: you can call next(iter(....)) on it properly.
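The map-style dataset survives `next(iter(...))` because of Python's legacy sequence protocol: when a class defines `__getitem__` but no `__iter__`, `iter()` builds an iterator that calls `__getitem__(0)`, `__getitem__(1)`, and so on until an `IndexError` is raised. The toy class below (`ToyMapDataset`, hypothetical) shows this without PyTorch. As far as I know `torch.utils.data.Dataset` does not define `__iter__`, so the same fallback presumably applies to `MapStyleDataset`, but treat that as an assumption to check for your PyTorch version.

```python
class ToyMapDataset:
    """Hypothetical map-style container: __getitem__ and __len__, no __iter__."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, idx):
        # List indexing raises IndexError past the end, which stops iteration
        return self.items[idx]

    def __len__(self):
        return len(self.items)


ds = ToyMapDataset(["x", "y", "z"])
it = iter(ds)    # works: Python falls back to the sequence protocol
print(next(it))  # x
print(list(it))  # ['y', 'z']
```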
But when I try next(iter(....)) on the iterable dataset, I keep getting this error:
TypeError: iter() returned non-iterator of type 'Tensor'
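The message comes from Python's iterator protocol rather than from PyTorch: `iter(obj)` calls `obj.__iter__()` and raises `TypeError` if the result is not an iterator (an object with `__next__`). Returning a single sample, as `MyIterableDataset.__iter__` does, breaks that contract, while turning `__iter__` into a generator function satisfies it. The two toy classes below are hypothetical and PyTorch-free, with an `int` playing the role of the `Tensor`.

```python
class BrokenIterable:
    def __iter__(self):
        # Returns a plain value instead of an iterator, so iter() rejects it
        return 42


class FixedIterable:
    def __iter__(self):
        # A generator function: calling it returns a generator, which is an iterator
        yield 42


try:
    iter(BrokenIterable())
except TypeError as exc:
    print(exc)  # iter() returned non-iterator of type 'int'

print(next(iter(FixedIterable())))  # 42
```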
How do I do this conversion properly? Thanks.
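One possible way to do the conversion, sketched under assumptions: wrap the map-style dataset in an `IterableDataset` whose `__iter__` is a generator that yields every index once per pass, optionally in shuffled order. `MapAsIterable`, `CountingStream` and the `shuffle` parameter are hypothetical names, and `TensorDataset` stands in for `MapStyleDataset` so the example runs without the original files. The sketch ignores multi-worker `DataLoader` loading: with `num_workers > 0`, each worker would yield the whole dataset unless the indices are split, for example with `torch.utils.data.get_worker_info()`.

```python
import torch
from torch.utils.data import ChainDataset, Dataset, IterableDataset, TensorDataset


class MapAsIterable(IterableDataset):
    """Hypothetical wrapper exposing a map-style dataset as an IterableDataset."""

    def __init__(self, dataset: Dataset, shuffle: bool = True):
        super().__init__()
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        n = len(self.dataset)
        # Visit each index exactly once per pass, optionally in random order
        order = torch.randperm(n).tolist() if self.shuffle else range(n)
        for idx in order:
            # yield makes __iter__ a generator, so iter() receives a real iterator
            yield self.dataset[idx]


class CountingStream(IterableDataset):
    """Stand-in for the other, already-iterable dataset (illustration only)."""

    def __init__(self, start: int, stop: int):
        super().__init__()
        self.start, self.stop = start, stop

    def __iter__(self):
        for i in range(self.start, self.stop):
            # Same tuple-of-tensors structure as TensorDataset items
            yield (torch.tensor([float(i)]),)


map_style = TensorDataset(torch.arange(3, dtype=torch.float32).unsqueeze(1))
chained = ChainDataset([MapAsIterable(map_style, shuffle=False), CountingStream(10, 12)])

for sample in chained:
    print(sample[0].item())  # prints 0.0, 1.0, 2.0, 10.0, 11.0 (one per line)
```

Shuffling with `torch.randperm` inside `__iter__` gives a fresh order on every pass, which is usually what a training loop wants from a converted map-style dataset.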