Here’s an example of my datapipe:
    def setup(self, stage: Optional[str]) -> None:
        self.train_dataset, self.val_dataset = (
            FileLister([self._path])
            .filter(filter_fn=json_filter)
            .open_files(mode='rt')
            .readlines()
            .shuffle(buffer_size=self._buffer_length)
            .map(load_json)
            .map(self.map_data)
            .batch(batch_size=self._batch_size, drop_last=self._drop_last)
            .collate()
            .random_split(
                total_length=self._buffer_length,
                weights={"train": self._train_fraction, "valid": 1 - self._train_fraction},
                seed=self._seed,
            )
        )
I have defined map_data as a method of the class so that it has access to the instance's attributes:
    def map_data(self, x):
This works, but it feels like a bit of a hack. The interpreter raises an error if I try to pass extra arguments, either like this:
    .map(map_data, args)
or like this:
    .map(map_data(args))
Do I need to create a partial function or something like that?
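For reference, this is roughly the partial-based approach I have in mind. It is only a minimal sketch: the free function map_data below and its scale and offset arguments are hypothetical stand-ins for my real mapping function, and Python's built-in map() stands in for the datapipe so the snippet runs without torchdata. My assumption is that the datapipe's .map() would accept the same callable.

```python
from functools import partial


def map_data(x, scale, offset):
    """Hypothetical stand-in for the real mapping function."""
    return x * scale + offset


# Bind the extra arguments up front. The result is a callable that takes
# only the element, which is the shape a .map() call expects.
map_with_args = partial(map_data, scale=2, offset=1)

# Built-in map() is used here in place of the datapipe; with a datapipe
# the call would presumably be .map(map_with_args).
result = list(map(map_with_args, [1, 2, 3]))
print(result)
```

As far as I understand, a partial over a module-level function can also be pickled, unlike a lambda, which may matter if the datapipe is later used with multiple workers.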