Create an abstract dataset class

Hello, I’m currently trying to create three dataset classes to manage my training, validation and test data. The rationale behind this choice is that the data in the three splits need to be manipulated in different ways. I’m not quite sure how to go about this and would appreciate your feedback on my current implementation.

My idea was to create an AbstractDataset class that inherits from both ABC, so it can have abstract methods, and Dataset. As you can see from the sketch below, the three dataset classes I need to implement would all inherit from AbstractDataset and define their own __getitem__ method. Is this a sensible approach, or can I achieve my goal in a cleaner way?

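from abc import ABC, abstractmethod
from pathlib import Path
from typing import List

import torch
from torch.utils.data import Dataset
from transformers import BatchEncoding


class AbstractDataset(ABC, Dataset):
    def __init__(self, csv_file: Path, checkpoint: str) -> None:
        ...  # <common to all 3 classes>

    def _tokenize_sample(self, sample: List[str]) -> BatchEncoding:
        ...  # <common to all 3 classes>

    def __len__(self) -> int:
        ...  # <common to all 3 classes>

    def _generate_embedding(self, tokenized: BatchEncoding) -> torch.Tensor:
        ...  # <common to all 3 classes>

    # Each split's subclass must override this; @abstractmethod makes
    # instantiating AbstractDataset itself raise a TypeError.
    @abstractmethod
    def __getitem__(self, index: int):
        raise NotImplementedError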

This looks like a very sensible implementation if the aim is to develop vastly different __getitem__ methods.
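A minimal runnable sketch of this pattern, assuming hypothetical in-memory float samples in place of the CSV file, tokenizer and embeddings (TrainDataset, EvalDataset and the noise "augmentation" are illustrative names, not part of the original code). Note that inheriting from ABC only prevents instantiating the base class when __getitem__ is decorated with @abstractmethod.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

import torch
from torch.utils.data import Dataset


class AbstractDataset(ABC, Dataset):
    """Shared loading and __len__; each split decides what __getitem__ returns."""

    def __init__(self, samples: List[float]) -> None:
        # Hypothetical stand-in for reading the CSV and tokenizing.
        self.samples = torch.tensor(samples)

    def __len__(self) -> int:
        return len(self.samples)

    @abstractmethod
    def __getitem__(self, index: int):
        """Subclasses must override this."""


class TrainDataset(AbstractDataset):
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.samples[index]
        # Hypothetical training-only manipulation: add a little Gaussian noise,
        # returning (noisy input, clean target).
        return x + 0.1 * torch.randn(()), x


class EvalDataset(AbstractDataset):
    def __getitem__(self, index: int) -> torch.Tensor:
        # Validation/test split: return the raw sample, no augmentation.
        return self.samples[index]
```

Both subclasses reuse __init__ and __len__ but return different things from __getitem__, while calling AbstractDataset([1.0]) directly raises a TypeError because __getitem__ is still abstract.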

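If by ‘manipulated in different ways’ you mean applying different transformations to each split, you could instead split a single Dataset object with Subset. Because every Subset wraps the same underlying dataset, first copy that dataset for each split (e.g. with copy.deepcopy) and then set each copy's transformations manually (for example, partition1.dataset.transforms = ... for the first split and partition2.dataset.transforms = ... for the second).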
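A minimal sketch of the Subset approach, assuming a hypothetical CsvLikeDataset that applies an optional transforms callable inside __getitem__ (the class, the 8/2 split and the doubling "augmentation" are illustrative only). The deep copies matter because a Subset only stores a reference to the wrapped dataset in its .dataset attribute, so without them every split would share one transforms setting.

```python
import copy
from typing import Callable, List, Optional

import torch
from torch.utils.data import Dataset, Subset


class CsvLikeDataset(Dataset):
    """Hypothetical single dataset; `transforms` is applied in __getitem__."""

    def __init__(
        self,
        values: List[float],
        transforms: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        self.values = torch.tensor(values)
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.values)

    def __getitem__(self, index: int) -> torch.Tensor:
        x = self.values[index]
        return self.transforms(x) if self.transforms is not None else x


full = CsvLikeDataset([float(i) for i in range(10)])
# Reproducible shuffled indices for the split.
indices = torch.randperm(len(full), generator=torch.Generator().manual_seed(0)).tolist()

# Each Subset wraps its own deep copy of the underlying dataset, so changing
# one split's transforms cannot leak into the other.
train_set = Subset(copy.deepcopy(full), indices[:8])
val_set = Subset(copy.deepcopy(full), indices[8:])

# Set the transforms on the wrapped dataset, not on the Subset object itself.
train_set.dataset.transforms = lambda x: x * 2  # hypothetical augmentation
val_set.dataset.transforms = None               # no augmentation
```

The trade-off is that each deep copy duplicates the data in memory, which is cheap here but may matter for large datasets; random_split produces Subset objects that likewise share one underlying dataset, so the same copying step would apply.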