I have a custom dataset that loads data from a bunch of text files. The dataset resembles a standard multi-class supervised classification problem.
However, instead of directly training it to classify into one of N classes, I am trying to train N binary classifiers (one classifier for each class). For this, I am transforming the original dataset into a one-vs-all format (where my target class has a label of 1 and all other classes have the label 0). I have created a transformation class called `OneVsAll` for this purpose, which takes in a `target_category` parameter and transforms the dataset into a "`target_category` vs. all" style dataset.
I would like to be able to create the Dataset object just once and apply N such `OneVsAll` transforms one by one. But the go-to method for applying transforms takes a single transformation (or a composition of transformations) in the constructor of the Dataset, which means that if I used it, I would be creating the dataset N times, which I don’t want to do.
Also, I want to avoid applying transformations on every individual sample as I feel it defeats the purpose of the elegant syntax that comes with something like specifying it in the constructor.
How should I go about this?
import os
import random
from torch.utils.data import Dataset, DataLoader
from gensim.models import word2vec
class SEMCATDataset(Dataset):
    """Dataset of word/category samples read from the files in ``WORDS_DIR``.

    Every sample is a dict ``{'word': <str>, 'category': <str>}``. All word
    files are read once, in ``__init__``, and kept in ``self.words``.

    ``self.transform`` is looked up on every ``__getitem__`` call, so it can
    be reassigned on an existing instance (``dataset.transform = ...``) to
    change how samples are presented without re-reading the word files.
    """

    WORDS_DIR = '../Categories'
    # Mapping from canonical category names to filenames. The category name
    # is the part of the filename before the first '-'. The directory is
    # listed once, when the class body is executed.
    CATEGORY_FILES = {
        f.split('-')[0]: f for f in os.listdir(WORDS_DIR)
    }

    def __init__(self, transform=None):
        """Load every word of every category file into ``self.words``.

        Args:
            transform: Optional callable applied to each sample dict in
                ``__getitem__``; its return value is what the dataset yields.
        """
        self.transform = transform
        self.words = []
        # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
        for category, words_file in SEMCATDataset.CATEGORY_FILES.items():
            self.words.extend(
                {'word': word, 'category': category}
                for word in self._get_words_for_category(words_file)
            )

    def _get_words_for_category(self, words_file):
        """Return the whitespace-stripped lines of one category file.

        Args:
            words_file: Filename relative to ``WORDS_DIR``.

        Returns:
            A list with one entry per line of the file.
        """
        words_file = os.path.join(SEMCATDataset.WORDS_DIR, words_file)
        with open(words_file, 'r') as f:
            # Build a real list: map() is a lazy, one-shot iterator on
            # Python 3, whereas callers may want to reuse the result.
            return [line.strip() for line in f]

    def __len__(self):
        """Return the number of loaded samples."""
        # self.words is the only sample store this class populates.
        return len(self.words)

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``self.transform`` if set."""
        sample = self.words[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample
class OneVsAll(object):
    """Callable transform that relabels a sample as target (1) or not (0).

    The returned dict keeps the sample's ``'word'`` and replaces
    ``'category'`` with ``1`` when it equals ``target_category`` and ``0``
    otherwise. The input sample is not modified.
    """

    def __init__(self, target_category):
        # Category that is treated as the positive class.
        self.target_category = target_category

    def __call__(self, sample):
        """Return a new ``{'word', 'category'}`` dict with a binary label."""
        word = sample['word']
        is_target = sample['category'] == self.target_category
        return dict(word=word, category=int(is_target))