Faster way to get target counts from torch.utils.data.Dataset?

I am getting into the habit of creating custom torch.utils.data.Dataset classes to load and prepare my data. It's text data, so I am tokenizing it “on the fly”.

The problem is that my data is imbalanced and I want to use WeightedRandomSampler, which requires that I know the distribution of my target class. That distribution is unknown for each split, because I am using the random_split function to split my data:

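train_size = int(0.8 * len(csv_dataset))
valid_size = int(0.1 * len(csv_dataset))
test_size = len(csv_dataset) - train_size - valid_size  # remainder, so the three sizes sum to len(csv_dataset)
train_ds, valid_ds, test_ds = torch.utils.data.random_split(csv_dataset, [train_size, valid_size, test_size])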
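random_split rejects a list of lengths that does not add up exactly to the dataset length, so the last split is safest computed as the remainder. A small standalone sketch of that arithmetic (the helper name `split_sizes` is hypothetical, and the 80/10 fractions are just the ones used in this thread):

```python
def split_sizes(n, fractions=(0.8, 0.1)):
    """Return split sizes that always sum exactly to n.

    Each fraction in `fractions` is floored; whatever is left over
    goes to one final split (here, the test set).
    """
    sizes = [int(f * n) for f in fractions]
    sizes.append(n - sum(sizes))  # remainder, so the total is always n
    return sizes


# Example with a dataset of 1005 rows
print(split_sizes(1005))  # [804, 100, 101]
```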

A single loop through the dataset takes a long time, because every item goes through the transformations (tokenization) when it is accessed. Is there another way around this?

I can recommend the new v1 datasets library by Hugging Face. That way you only have to tokenize once, and the result is cached for every later run of the script.

Thanks so much, I will look into it, although I am still hoping for a pure PyTorch solution.

You can design your own torch.utils.data.Dataset. For example, load data from a .txt file in which each line includes image_path, label and label_count, so you can load the image, label and label_count simultaneously.
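The same idea carries over to text: keep the labels in an attribute that is filled once when the dataset is built, and leave the expensive work (tokenization) inside `__getitem__`, so counting labels never touches it. A minimal sketch, assuming the texts and labels are already in memory as lists; the class name `LabeledTextDataset`, the `label_counts` method and the toy `str.split` tokenizer are hypothetical stand-ins:

```python
from collections import Counter

import torch
from torch.utils.data import Dataset


class LabeledTextDataset(Dataset):
    """Hypothetical dataset that keeps raw labels separate from tokenization."""

    def __init__(self, texts, labels, tokenize):
        self.texts = list(texts)
        self.labels = list(labels)  # cheap to read, no tokenization needed
        self.tokenize = tokenize    # expensive step, only run inside __getitem__

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokens = self.tokenize(self.texts[idx])  # the "on the fly" cost lives here
        return {"text": tokens, "target": torch.tensor(self.labels[idx])}

    def label_counts(self):
        # Counts come straight from the stored labels, so __getitem__ is never called
        return Counter(self.labels)


ds = LabeledTextDataset(["good", "bad", "fine"], [1, 0, 1], tokenize=str.split)
print(ds.label_counts())  # Counter({1: 2, 0: 1})
```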

Ah right, that is a possibility. I forgot to mention that I am loading / building my data from a CSV file into my torch.utils.data.Dataset.

Update: The problem I have now is more specifically with the Subset class.

In my CSV class, I can get the distribution of my target easily with:

    def __len_target__(self):
        return np.bincount(self.data_frame.target)
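For reference, np.bincount counts how often each non-negative integer occurs, so entry i is the number of rows in class i. That is why the target needs to be label-encoded to 0..K-1 first. A tiny example with made-up labels:

```python
import numpy as np

# Label-encoded targets for six rows: classes 0, 1 and 2
targets = np.array([0, 1, 1, 2, 2, 2])

counts = np.bincount(targets)
print(counts)  # [1 2 3]

# Relative frequency of each class
print(counts / counts.sum())
```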

The problem is that this method is unavailable once I use torch.utils.data.random_split to generate the training, validation, and test sets: random_split returns Subset objects, which lazily wrap the original dataset and do not expose its custom methods.
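The Subset objects returned by random_split do keep a reference to the parent dataset in `.dataset` and the selected row positions in `.indices`, so a subset's labels can be recovered from the parent's label array without calling `__getitem__`. A minimal sketch, assuming the parent's labels are available as a NumPy array; `subset_targets` is a hypothetical helper name, and a `TensorDataset` stands in for the CSV dataset:

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the CSV dataset: 10 rows, labels already encoded as integers
labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])
dataset = TensorDataset(torch.arange(10), torch.from_numpy(labels))

# Fixed seed only so the split is reproducible
train_ds, valid_ds = random_split(
    dataset, [8, 2], generator=torch.Generator().manual_seed(0)
)


def subset_targets(subset, all_targets):
    """Look up a Subset's labels by position, without calling __getitem__."""
    return np.asarray(all_targets)[subset.indices]


print(subset_targets(train_ds, labels))
```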

Last update: I think I have found a quick solution. Here is what I did in case it helps anyone else in the future.

First, prepare a dataset class for CSV:

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset


# Create Dataset
class CSVDataset(Dataset):
    """Lazy Loading."""

    def __init__(self, csv_file, text_col, target, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            text_col (list of str): Column(s) holding the text features.
            target (str): Name of the target column.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.text_features = text_col
        self.target = target
        self.transform = transform

        # encode outcome as integers 0..K-1 (np.bincount needs non-negative ints)
        self.data_frame[target] = LabelEncoder().fit_transform(self.data_frame[target])

    def __len__(self):
        return len(self.data_frame)

    def __get_target__(self):
        # Read the encoded target column directly; no per-row transform is run
        return self.data_frame[self.target]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        text = self.data_frame.iloc[idx][self.text_features]
        target = self.data_frame.iloc[idx][self.target]  # a single encoded label

        sample = {'text': text.values, 'target': target, 'idx': torch.tensor(idx)}

        if self.transform:
            sample = self.transform(sample)

        return sample

csv_dataset = CSVDataset(csv_file='C:\\Users\\14348\\Desktop\\df2.csv',
                         text_col=['body'],
                         target='target')

Next, use random_split to generate train, valid and test sets.

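train_size = int(0.8 * len(csv_dataset))
valid_size = int(0.1 * len(csv_dataset))
test_size = len(csv_dataset) - train_size - valid_size  # remainder, so the sizes sum to len(csv_dataset)

train_ds, valid_ds, test_ds = torch.utils.data.random_split(csv_dataset, [train_size, valid_size, test_size])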

Next, get the indices that random_split used to create each Subset (stored in its .indices attribute):

train_indices = train_ds.indices
valid_indices = valid_ds.indices
test_indices = test_ds.indices

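Lastly, index the full data set's target column with each set of indices and run np.bincount on the result to get each split's target distribution. This reads the labels straight from the data frame, so no tokenization is triggered.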

train_ds_bincount = np.bincount(csv_dataset.__get_target__().to_numpy()[train_indices])
valid_ds_bincount = np.bincount(csv_dataset.__get_target__().to_numpy()[valid_indices])
test_ds_bincount = np.bincount(csv_dataset.__get_target__().to_numpy()[test_indices])
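One caveat: np.bincount returns only as many bins as the largest label present plus one, so if the highest-numbered class happens to be missing from a split, that split's array comes out shorter than the others. Passing `minlength` keeps all three arrays aligned. A sketch with made-up labels and a hypothetical split in which class 2 never reaches the test set:

```python
import numpy as np

all_targets = np.array([0, 1, 0, 2, 0, 1, 0, 0, 2, 1])
train_indices = [0, 2, 3, 4, 5, 6, 7, 8]
test_indices = [1, 9]  # both rows are class 1, so class 2 is absent here

num_classes = int(all_targets.max()) + 1  # classes are encoded 0..K-1

print(np.bincount(all_targets[test_indices]))  # [0 2], only two bins

train_counts = np.bincount(all_targets[train_indices], minlength=num_classes)
test_counts = np.bincount(all_targets[test_indices], minlength=num_classes)
print(train_counts)  # [5 1 2]
print(test_counts)   # [0 2 0]
```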

Now we know the label distribution of each lazily loaded data set, which is what is needed to build the weights for torch.utils.data.WeightedRandomSampler.
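Note that WeightedRandomSampler takes one weight per sample rather than per class, so the training counts still have to be turned into per-sample weights. Inverse class frequency is one common choice, not the only one. A minimal sketch on made-up labels; in the setup above, `train_targets` would presumably be `csv_dataset.__get_target__().to_numpy()[train_indices]`:

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

# Made-up encoded labels for a training split (class 0 is the majority)
train_targets = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2])

class_counts = np.bincount(train_targets)      # [6 2 1]
class_weights = 1.0 / class_counts             # rarer classes get larger weights
sample_weights = class_weights[train_targets]  # one weight per training row

sampler = WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(sample_weights),
    replacement=True,
)

# Positions drawn by the sampler, relative to train_ds (0..len(train_ds)-1)
drawn = list(sampler)
print(np.bincount(train_targets[drawn], minlength=3))  # class counts in one random draw
```

The sampler would then be passed as `DataLoader(train_ds, sampler=sampler, ...)`; DataLoader does not allow `shuffle=True` together with a sampler, so leave `shuffle` at its default.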