So we now have operator+ on datasets, which is neat! However, I’m now looking into merging datasets where I need to increment the labels of the later datasets, e.g. concatenating MNIST (labels 0-9) with SVHN (labels 0-9) so that the SVHN labels become 10-19.
What is the most straightforward way to do this?
Ideally I don’t want to iterate over all the data and manually create a new dataset that does this.
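For concreteness, one dataset-level option I can picture is a thin wrapper that applies the offset lazily in __getitem__ instead of rewriting the underlying data (a sketch; LabelOffsetDataset is hypothetical, not something in torchvision):

from torch.utils.data import Dataset

class LabelOffsetDataset(Dataset):
    # wrap a dataset and add a constant offset to every label, lazily
    def __init__(self, dataset, offset):
        self.dataset = dataset
        self.offset = offset

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        return sample, label + self.offset

    def __len__(self):
        return len(self.dataset)

With that, something like mnist + LabelOffsetDataset(svhn, 10) would already give one concatenated dataset with disjoint label ranges, since operator+ on datasets returns a ConcatDataset.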
Then I got to thinking why not handle it at the loader level, eg:
import numpy as np

class ConcatIncrementLabelLoader(object):
    """
    Loader to concatenate multiple loaders.

    Purpose: useful to assemble different existing loaders, possibly
    large-scale ones, since the concatenation is done on the fly.

    Arguments:
        loaders (iterable): list of loaders to be concatenated
    """
    def __init__(self, loaders):
        assert len(loaders) > 0, 'loaders should not be an empty iterable'
        self.loaders = list(loaders)
        # find_max_label / num_samples_in_loader are assumed helpers;
        # +1 so the label ranges stay disjoint (assuming labels start at 0)
        self.max_labels = [find_max_label(loader) + 1 for loader in loaders]
        # cumulative sum: max_labels[i - 1] is the offset for loader i
        for j in range(1, len(self.max_labels)):
            self.max_labels[j] += self.max_labels[j - 1]
        self.total_sample_length = sum(num_samples_in_loader(loader)
                                       for loader in loaders)

    def __len__(self):
        return self.total_sample_length

    def __iter__(self):
        # pick a random loader and pull a single batch from it;
        # a real implementation would have to keep yielding until
        # every loader is exhausted
        loader_idx = np.random.randint(len(self.loaders))
        samples, labels = next(iter(self.loaders[loader_idx]))
        if loader_idx > 0:
            # shift labels by the cumulative max of all previous loaders
            labels = labels + self.max_labels[loader_idx - 1]
        yield samples, labels
But this seems nightmarish, since I now need to work around the logic of the workers, make sure a StopIteration is raised at the right time, etc.
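A simpler loader-level sketch that at least sidesteps the manual __next__/StopIteration juggling would be a plain generator that chains the loaders in order (sequentially rather than sampling at random; offsets is assumed to be the precomputed list of cumulative label offsets, with offsets[0] == 0):

def chain_loaders(loaders, offsets):
    # yield every batch from each loader in turn,
    # shifting its labels by that loader's offset
    for loader, offset in zip(loaders, offsets):
        for samples, labels in loader:
            yield samples, labels + offset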
Thanks for the response! In my scenario the dataset types are generally provided via argparse.
The problem isn’t quite finding the max, but rather incrementing all of the labels by an offset. Most of the datasets I’m using are already in http://pytorch.org/docs/master/torchvision/datasets.html ; however, as far as I can tell there is no consistent member variable for the labels, e.g.:
import torch
from torchvision.datasets import MNIST, SVHN

# receive the list of datasets from argparse;
# as an example we use two below.
mnist = MNIST('./mnist')
svhn = SVHN('./svhn')
# some other datasets here
datasets = [mnist, svhn, ... ]

# find each dataset's max label (+1 so ranges stay disjoint) and cumsum;
# note this already breaks: MNIST stores its labels in .train_labels,
# while SVHN stores them in .labels
dataset_label_max = [torch.max(dataset.train_labels) + 1 for dataset in datasets]
for i in range(1, len(dataset_label_max)):
    dataset_label_max[i] += dataset_label_max[i - 1]

# offset each dataset's labels by the cumulative max of its predecessors
for j in range(1, len(datasets)):
    datasets[j].train_labels += dataset_label_max[j - 1]
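One workaround I can think of is a small helper that probes the attribute names used by the different torchvision classes (a sketch; the tuple of names is an assumption based on the classes I’ve looked at, e.g. .train_labels for MNIST and .labels for SVHN, with .targets included in case a dataset uses that):

LABEL_ATTRS = ('train_labels', 'labels', 'targets')

def label_attr(dataset):
    # return the name of whichever label attribute this dataset exposes
    for attr in LABEL_ATTRS:
        if hasattr(dataset, attr):
            return attr
    raise AttributeError('no known label attribute on %r' % dataset)

def offset_labels(dataset, offset):
    # increment the dataset's labels in place by a constant offset
    attr = label_attr(dataset)
    setattr(dataset, attr, getattr(dataset, attr) + offset)

Then the last loop above becomes offset_labels(datasets[j], dataset_label_max[j - 1]).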