Hello, I’m currently trying to create three dataset classes to manage my training, validation and test data. The rationale behind this choice is that the data in the three splits need to be manipulated in different ways. I’m not quite sure how to go about this and would appreciate your feedback on my current implementation.
My idea was to create an `AbstractDataset` class that inherits both from `ABC`, so it can have abstract methods, and from `Dataset`. As you can see from the sketch below, the three dataset classes I need to implement would all inherit from `AbstractDataset` and define their own `__getitem__` method. Is this a sensible approach, or can I achieve my goal in a cleaner way?
```python
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
import torch
from torch.utils.data import Dataset
from transformers import BatchEncoding

class AbstractDataset(ABC, Dataset):
    def __init__(self, csv_file: Path, checkpoint: str) -> None:
        ...  # <common to all 3 classes>

    def _tokenize_sample(self, sample: List[str]) -> BatchEncoding:
        ...  # <common to all 3 classes>

    def __len__(self) -> int:
        ...  # <common to all 3 classes>

    def _generate_embedding(self, tokenized: BatchEncoding) -> torch.Tensor:
        ...  # <common to all 3 classes>

    @abstractmethod
    def __getitem__(self, index: int):
        raise NotImplementedError  # subclasses must override; raise it, don't return it
```
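To make the idea concrete, a minimal, self-contained sketch of the three subclasses could look like the code below. It rests on a few assumptions. The CSV loading, tokenization and embedding steps are replaced by a hypothetical in-memory `rows` list and a hypothetical `_preprocess` helper. The class names `TrainDataset`, `ValidationDataset` and `TestDataset` and the per-split differences (e.g. the test split having no labels) are illustrative only. A stub `Dataset` is defined only when PyTorch is not installed.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List

try:
    from torch.utils.data import Dataset
except ImportError:  # fallback stub so this sketch runs without PyTorch installed
    class Dataset:  # type: ignore[no-redef]
        pass


class AbstractDataset(Dataset, ABC):
    """Logic shared by all splits; subclasses decide what one sample looks like."""

    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        # Hypothetical stand-in for reading the CSV file and loading the tokenizer.
        self.rows = rows

    def __len__(self) -> int:
        return len(self.rows)

    def _preprocess(self, text: str) -> str:
        # Hypothetical stand-in for _tokenize_sample / _generate_embedding.
        return text.strip().lower()

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Any]:
        raise NotImplementedError


class TrainDataset(AbstractDataset):
    def __getitem__(self, index: int) -> Dict[str, Any]:
        row = self.rows[index]
        # Training samples: model input plus label (split-specific augmentation would go here).
        return {"text": self._preprocess(row["text"]), "label": row["label"]}


class ValidationDataset(AbstractDataset):
    def __getitem__(self, index: int) -> Dict[str, Any]:
        row = self.rows[index]
        # Validation samples: also keep the row id so predictions can be inspected later.
        return {"id": row["id"], "text": self._preprocess(row["text"]), "label": row["label"]}


class TestDataset(AbstractDataset):
    def __getitem__(self, index: int) -> Dict[str, Any]:
        row = self.rows[index]
        # Test samples: no label is returned, only the id and the model input.
        return {"id": row["id"], "text": self._preprocess(row["text"])}
```

Because `__getitem__` is declared abstract, calling `AbstractDataset(rows)` directly raises `TypeError`, while each concrete subclass can be used wherever a `Dataset` is expected, e.g. wrapped in a `torch.utils.data.DataLoader`.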