How to apply multiple transforms on a dataset without recreating the dataset every time?

I have a custom dataset that loads data from a bunch of text files. The dataset resembles a standard multi-class supervised classification problem.

However, instead of directly training it to classify into one of N classes, I am trying to train N binary classifiers (one classifier for each class). For this, I am transforming the original dataset into a one-vs-all format (where my target class has a label of 1 and all other classes have the label 0). I have created a transformation class called OneVsAll for this purpose, which takes in a target_category parameter and transforms the dataset into a "target_category vs all" style dataset.

I would like to be able to create the Dataset object just once and apply N such OneVsAll transforms one by one. But the go-to method for applying transforms takes a single transformation (or a composition of transformations) in the constructor of the Dataset, which means if I used it, I would be creating the dataset N times, which I don’t want to do.

Also, I want to avoid applying the transformation manually to every individual sample, since I feel that defeats the purpose of the elegant syntax of specifying it in the constructor.

How should I go about this?

import os
import random

from torch.utils.data import Dataset, DataLoader

from gensim.models import word2vec


class SEMCATDataset(Dataset):

    # Mapping from canonical category names to filenames
    WORDS_DIR = '../Categories'
    CATEGORY_FILES = {
        f.split('-')[0]: f for f in os.listdir(WORDS_DIR)
    }

    def __init__(self, transform=None):
        self.transform = transform
        self.words = []
        # .items() works on Python 3 (dict.iteritems() was removed)
        for category, words_file in SEMCATDataset.CATEGORY_FILES.items():
            self.words.extend([{
                'word': word,
                'category': category
            } for word in self._get_words_for_category(words_file)])

    def _get_words_for_category(self, words_file):
        words_file = os.path.join(
            SEMCATDataset.WORDS_DIR, words_file)
        with open(words_file, 'r') as f:
            # Strip the trailing newline/whitespace from each line
            words = [line.strip() for line in f]
            return words

    def __len__(self):
        # self.words holds every sample; target_words/non_target_words were never defined
        return len(self.words)

    def __getitem__(self, idx):
        sample = self.words[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample


class OneVsAll(object):

    def __init__(self, target_category):
        self.target_category = target_category

    def __call__(self, sample):
        return {
            'word': sample['word'],
            'category': int(sample['category'] == self.target_category)
        }

You could either hold on to the OneVsAll object and change its target_category attribute, or keep track of a target_category in the dataset itself, use it in __getitem__, and change it whenever you want a different category.
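A minimal sketch of the first suggestion: build the dataset and a single OneVsAll instance once, then change its target_category between training runs. To keep it runnable without PyTorch, InMemoryWordDataset is a hypothetical plain-Python stand-in for SEMCATDataset with the same __getitem__/transform logic, and the toy word list is made up; in real code you would keep subclassing torch.utils.data.Dataset.

```python
class OneVsAll:
    """Relabels a sample: 1 if it belongs to target_category, else 0."""

    def __init__(self, target_category):
        self.target_category = target_category

    def __call__(self, sample):
        return {
            'word': sample['word'],
            'category': int(sample['category'] == self.target_category),
        }


class InMemoryWordDataset:
    """Hypothetical plain-Python stand-in for SEMCATDataset.

    Samples are built once; in real code this would subclass
    torch.utils.data.Dataset and read the category files.
    """

    def __init__(self, samples, transform=None):
        self.words = [{'word': w, 'category': c} for w, c in samples]
        self.transform = transform

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        sample = self.words[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample


# Toy data, made up for illustration.
SAMPLES = [('dog', 'animal'), ('cat', 'animal'), ('red', 'color'), ('car', 'vehicle')]
CATEGORIES = sorted({c for _, c in SAMPLES})

# Build the dataset and the transform exactly once.
one_vs_all = OneVsAll(target_category=None)
dataset = InMemoryWordDataset(SAMPLES, transform=one_vs_all)

positives = {}
for category in CATEGORIES:
    # Retarget the shared transform instead of rebuilding the dataset.
    one_vs_all.target_category = category
    labels = [dataset[i]['category'] for i in range(len(dataset))]
    positives[category] = sum(labels)
    # ... train the binary classifier for `category` here ...

print(positives)  # {'animal': 2, 'color': 1, 'vehicle': 1}
```

The second suggestion is the same idea, except the dataset stores target_category itself and does the comparison inside __getitem__. One caveat, as far as I understand DataLoader: with num_workers > 0 each worker process holds its own copy of the dataset (and therefore of the transform), so change target_category before starting a new pass over the loader, and avoid persistent_workers=True, which keeps the old copies alive.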

That said, if you prefer not to modify the dataset, you could also just return the category, do the comparison cat == target_cat on the cat Tensor in the training loop itself, and claim that you’re following the Zen of Python's "Explicit is better than implicit".
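A sketch of this second suggestion, assuming the dataset is changed (hypothetically, as IndexedWordDataset below) to return each category as an integer index, so that DataLoader's default collation batches the indices into a LongTensor. The binary labels are then computed per batch in the loop. The toy samples are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexedWordDataset(Dataset):
    """Hypothetical variant of SEMCATDataset returning (word, category_index)."""

    def __init__(self, samples):
        categories = sorted({c for _, c in samples})
        self.cat_to_idx = {c: i for i, c in enumerate(categories)}
        self.samples = [(w, self.cat_to_idx[c]) for w, c in samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # No transform: just the word and its category index.
        return self.samples[idx]


SAMPLES = [('dog', 'animal'), ('cat', 'animal'), ('red', 'color'), ('car', 'vehicle')]
dataset = IndexedWordDataset(SAMPLES)  # built once
loader = DataLoader(dataset, batch_size=2, shuffle=False)

positives = {}
for name, target_idx in dataset.cat_to_idx.items():
    count = 0
    for words, cats in loader:
        # Default collation turns the int indices into a LongTensor.
        labels = (cats == target_idx).float()  # 1.0 for the target class, else 0.0
        count += int(labels.sum().item())
        # ... feed `words` / `labels` to the binary classifier for `name` ...
    positives[name] = count

print(positives)  # {'animal': 2, 'color': 1, 'vehicle': 1}
```

Nothing here relies on a transform: the dataset is created once, and the only per-class logic is the single comparison line in the loop.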

Best regards

Thomas


Thanks @tom. I think I’ll go ahead with the second approach. It sounds like the best option, even though it means forgoing PyTorch’s transform route.