If anyone is interested, I’ve implemented a simplified version. It works the same way, but also supports Seq2Seq datasets, i.e., where the inputs and targets are sequences. This means that each batch will only contain samples with the same combination of input and target lengths, making any padding unnecessary.
import numpy as np
from torch.utils.data import Dataset, Sampler
class BaseDataset(Dataset):
    """Minimal map-style dataset over parallel input/target collections.

    Each item is returned as a NumPy array (or a pair of arrays). When
    ``targets`` is None the dataset is treated as unlabeled and only the
    input array is returned.
    """

    def __init__(self, inputs, targets):
        # `targets` may be None for unlabeled / inference-only data.
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        sample = np.asarray(self.inputs[index])
        # Guard clause: unlabeled dataset yields the input alone.
        if self.targets is None:
            return sample
        return sample, np.asarray(self.targets[index])
class EqualLengthsBatchSampler(Sampler):
    """Batch sampler that groups samples by (input length, target length).

    Every yielded batch contains only indices of samples whose inputs all
    have the same length and whose targets all have the same length, so no
    padding is needed. Scalar targets (e.g., class labels) and ``targets is
    None`` are supported: their "length" is treated as the constant 1.

    Intended to be passed as ``batch_sampler=`` to a ``DataLoader``; batches
    (length groups and the order within each group) are reshuffled on every
    iteration.
    """

    def __init__(self, batch_size, inputs, targets):
        """
        Args:
            batch_size: maximum number of indices per yielded batch.
            inputs: sequence of input samples (each must support ``len``).
            targets: sequence of targets (sequences or scalars), or None.

        Raises:
            ValueError: if ``inputs`` and ``targets`` have different sizes.
        """
        if targets is not None and len(inputs) != len(targets):
            raise ValueError("[EqualLengthsBatchSampler] inputs and targets have different sizes")

        self.batch_size = batch_size
        self.num_samples = len(inputs)

        # Map each (input_length, target_length) pair to the sample indices
        # that have it, e.g. {(4, 5): [3, 5, 11], (5, 5): [1, 2, 9], ...}
        self.lengths_to_samples = {}
        for i in range(self.num_samples):
            len_input = len(inputs[i])
            try:
                len_target = len(targets[i])
            except TypeError:
                # targets[i] is a scalar (class label) or targets is None;
                # the value 1 is an arbitrary constant placeholder.
                len_target = 1
            self.lengths_to_samples.setdefault((len_input, len_target), []).append(i)

        # List of all distinct length pairs, so it can be shuffled in __iter__.
        self.unique_length_pairs = list(self.lengths_to_samples.keys())

        # Number of batches across all groups (integer ceiling division per
        # group). This is what __len__ must report for a batch sampler.
        self.num_batches = sum(
            (len(indices) + batch_size - 1) // batch_size
            for indices in self.lengths_to_samples.values()
        )

    def __len__(self):
        # For a batch sampler, len() is the number of BATCHES per epoch
        # (DataLoader reports len(batch_sampler)), not the number of samples.
        return self.num_batches

    def __iter__(self):
        # Shuffle the order in which length groups are visited.
        np.random.shuffle(self.unique_length_pairs)

        for length_pair in self.unique_length_pairs:
            # All sample indices sharing this (input, target) length pair.
            sequence_indices = np.array(self.lengths_to_samples[length_pair])
            # Shuffle within the group so batch membership varies per epoch.
            np.random.shuffle(sequence_indices)

            # Split the group into ceil(n / batch_size) near-equal batches;
            # each batch therefore has at most `batch_size` indices.
            num_batches = (len(sequence_indices) + self.batch_size - 1) // self.batch_size
            for batch_indices in np.array_split(sequence_indices, num_batches):
                yield np.asarray(batch_indices)
I just implemented it for a tutorial, but maybe it’s useful for others.