I am trying to generate three separate torch data sets from the torch Dataset class but I am not sure how to do it or if it is possible, owing to __getitem__
's functionality. I would like to have separate torch data sets that I can then pass into data loaders on their own. Can I return 3 separate torch data sets from this class or do I need to make a class for each train, dev, and test set file?
Below is an example:
import torch
from torch.utils.data import DataLoader, Dataset
import pytreebank
import torchtext as tt
from torchtext import datasets
import pandas as pd
# prepare torch data set
class SST5DataSet(torch.utils.data.Dataset):
'''
This prepares the official Stanford Sentiment Treebank
(https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf) for traininig
in PyTorch.
'''
def __init__(self, is_finegrained=True, transform=None):
# initialize flag; set false for binary classification
self.is_finegrained = is_finegrained
# binary categories: 1 and 3 only
self.binary = [1, 3]
# download sst if necessary
tt.datasets.SST.download(root='.data')
# load sst
dataset = pytreebank.load_sst('.data\\sst\\trees')
# prase train test dev
for category in ['train', 'test', 'dev']:
with open('.data\\sst\\trees\\sst_{}.txt'.format(category), 'w') as outfile:
for item in dataset[category]:
outfile.write("{}\t{}\n".format(
item.to_labeled_lines()[0][0],
item.to_labeled_lines()[0][1]))
# initialize train, test dev
self.train = pd.read_csv('.data\\sst\\trees\\sst_train.txt', sep='\t',
header=None, names=['label', 'text'],
encoding='latin-1')
self.dev = pd.read_csv('.data\\sst\\trees\\sst_dev.txt', sep='\t',
header=None, names=['label', 'text'],
encoding='latin-1')
self.test = pd.read_csv('.data\\sst\\trees\\sst_test.txt', sep='\t',
header=None, names=['label', 'text'],
encoding='latin-1')
# filter if fine_grained
if self.is_finegrained is False:
self.train = self.train.loc[self.train['label'].isin(self.binary)].reset_index(drop=True)
self.dev = self.dev.loc[self.dev['label'].isin(self.binary)].reset_index(drop=True)
self.test = self.test.loc[self.test['label'].isin(self.binary)].reset_index(drop=True)
# map to 0, 1
self.train['label'] = self.train['label'].map({1: 0, 3: 1})
self.dev['label'] = self.dev['label'].map({1: 0, 3: 1})
self.test['label'] = self.test['label'].map({1: 0, 3: 1})
# initialize the transform if specified
self.transform = transform
# get len
def __len__(self):
return self.train.shape[0], self.dev.shape[0], self.test.shape[0]
# pull a sample of data
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
# return train, dev, test
train = {'text': self.train.text[idx],
'label': self.train.label[idx],
'idx': idx}
dev = {'text': self.dev.text[idx],
'label': self.dev.label[idx],
'idx': idx}
test = {'text': self.test.text[idx],
'label': self.test.label[idx],
'idx': idx}
return train, dev, test
# concats them all together
out = SST5DataSet(is_finegrained=False)
# check data
for i, batch in enumerate(out):
if i == 0:
break
batch