So we now have `operator+` on datasets, which is neat! However, I'm looking into merging datasets where I need to increment the labels of each subsequent dataset, e.g.:
DS 1: max_lbl = 10
DS 2: max_lbl = max_lbl(DS 1) + current_max_lbl
What is the most straightforward way to do this?
Ideally I don’t want to iterate over all the data and manually create a new dataset that does this.
Then I got to thinking: why not handle it at the loader level, e.g.:
import numpy as np


class ConcatLoader:
    """Loader to concatenate multiple loaders.

    Purpose: useful to assemble different existing loaders, possibly
    large-scale loaders, as the concatenation operation is done in an
    online fashion.

    Args:
        loaders (iterable): List of loaders to be concatenated
    """
    # find_max_label / num_samples_in_loader are helpers assumed elsewhere
    def __init__(self, loaders):
        assert len(loaders) > 0, 'loaders should not be an empty iterable'
        self.loaders = list(loaders)
        # cumulative sum of per-loader max labels, so loader j's labels
        # can be offset by the running total of its predecessors
        self.max_labels = [find_max_label(loader) for loader in self.loaders]
        for j in range(1, len(self.max_labels)):
            self.max_labels[j] += self.max_labels[j - 1]
        self.total_sample_length = 0
        for loader in self.loaders:
            self.total_sample_length += num_samples_in_loader(loader)

    def __next__(self):
        # sample a random loader and offset its labels by the cumulative
        # max of all loaders before it
        loader_idx = np.random.randint(len(self.loaders))
        samples, labels = next(iter(self.loaders[loader_idx]))
        if loader_idx > 0:
            labels = labels + self.max_labels[loader_idx - 1]
        return samples, labels
But this seems nightmarish, since I now need to work around the worker logic, make sure StopIteration is raised correctly, etc.
Any suggestions would be helpful
Do you know the `Dataset` beforehand, i.e. are you creating it yourself?
If so, you could store the max label as a member variable and just read it back:
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
        self.max_label = torch.max(targets)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return len(self.data)


n_classes = 10
data = torch.randn(100, 2)
targets = torch.LongTensor(100).random_(n_classes)
dataset = MyDataset(data, targets)
loader = DataLoader(dataset)
Thanks for the response! In my scenario the type of `Dataset` is generally provided via argparse.
The problem isn't quite finding the max, but rather incrementing all of the labels by an offset. Most of the datasets I'm using are already in http://pytorch.org/docs/master/torchvision/datasets.html , however as far as I can tell there is no consistent member variable for the labels (e.g. MNIST uses train_labels, SVHN uses labels).
So I can't simply do something as follows:
# receive list of datasets from argparse;
# as an example we use two below.
mnist = MNIST('./mnist')
svhn = SVHN('./svhn')
# some other datasets here
datasets = [mnist, svhn, ... ]

# find per-dataset max labels and take the cumulative sum
dataset_label_max = [torch.max(dataset.train_labels) for dataset in datasets]
for i in range(1, len(dataset_label_max)):
    dataset_label_max[i] += dataset_label_max[i - 1]

# offset each dataset's labels by the cumulative max of its predecessors
for j in range(1, len(datasets)):
    datasets[j].train_labels += dataset_label_max[j - 1]
Ok, I understand. A workaround would be to try several possible attributes and return a warning/error if no valid attribute was found:
max_label = 0
if hasattr(dataset, 'train_labels'):
    max_label = torch.max(dataset.train_labels)
elif hasattr(dataset, 'labels'):
    max_label = torch.max(dataset.labels)
elif hasattr(dataset, 'target'):
    max_label = torch.max(dataset.target)
else:
    print('No labels found!')
Good call. This seems like the best solution for now. Will submit a ticket to torchvision to try to get this standardized though. Thanks!