How to pass arguments to datapipe map function

Here is an example of my datapipe:

    def setup(self, stage: Optional[str]) -> None:
        self.train_dataset, self.val_dataset = (
            self._source_datapipe  # hypothetical name; the source datapipe was not shown
            .map(self.map_data)
            .batch(batch_size=self._batch_size, drop_last=self._drop_last).collate()
            .random_split(total_length=self._buffer_length,
                          weights={"train": self._train_fraction, "valid": 1 - self._train_fraction},
                          seed=self._seed)
        )

I have defined map_data as a method of the class so it has access to the instance attributes:

    def map_data(self, x):
        ...  # body omitted; reads instance attributes via self

This works, but it seems a bit of a hack. Python raises an error if I try to pass extra arguments, either like this:

    .map(map_data, args)

or like this:


Do I need to create a partial function or something like that?
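A partial function is one common way to do this. Below is a minimal sketch using `functools.partial`, with a hypothetical `scale_and_shift` function and Python's built-in `map()` standing in for the datapipe so it runs without torchdata installed:

```python
import pickle
from functools import partial


def scale_and_shift(x, scale, shift):
    """Hypothetical mapping function that needs two extra arguments."""
    return x * scale + shift


# partial() fixes scale and shift now and returns a callable that
# only needs x, which is the shape a map-style API expects: fn(x).
scale_fn = partial(scale_and_shift, scale=2, shift=1)

# Built-in map() stands in for the datapipe; with torchdata the
# analogous call would be datapipe.map(scale_fn).
result = list(map(scale_fn, [1, 2, 3]))
print(result)  # [3, 5, 7]

# Unlike a lambda, a partial of a module-level function survives a
# pickle round trip, which matters when the pipeline has to be pickled
# (for example, to send it to worker processes).
restored = pickle.loads(pickle.dumps(scale_fn))
print(restored(4))  # 9
```

Inside `setup`, the same idea would look something like `.map(partial(self.map_data, extra=value))`, where `extra` is a hypothetical keyword parameter of `map_data`. A lambda such as `lambda x: scale_and_shift(x, 2, 1)` also works in a single process, but the standard `pickle` module cannot serialize lambdas.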