[Solved] Simplest Way to Increment Multiple Dataset Labels (using the operator+ )?

So we now have the `operator+` on datasets, which is neat! However, I'm looking into merging datasets where I need to increment the labels of each dataset by an offset, e.g.:

DS1: max_lbl = 10
DS2: max_lbl = max_lbl(DS1) + current_max_lbl(DS2)
etc.
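
Concretely, the cumulative offsets I have in mind would be computed something like this (just an illustrative sketch; the max_lbl values are made up):

from itertools import accumulate

# per-dataset max labels (made-up numbers)
max_labels = [10, 9, 25]

# offset for dataset i = cumulative max label of all previous datasets
offsets = [0] + list(accumulate(max_labels))[:-1]
print(offsets)  # [0, 10, 19]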

What is the most straightforward way to do this?
Ideally I don’t want to iterate over all the data and manually create a new dataset that does this.
Then I got to thinking: why not handle it at the loader level, e.g.:

import numpy as np


# find_max_label / num_samples_in_loader are placeholder helpers used below.
class ConcatIncrementLabelLoader(object):
    """
    Loader to concatenate multiple loaders.
    Purpose: useful to assemble different existing loaders, possibly
    large-scale loaders, as the concatenation is done on the fly.

    Arguments:
        loaders (iterable): List of loaders to be concatenated
    """

    def __init__(self, loaders):
        assert len(loaders) > 0, 'loaders should not be an empty iterable'
        self.loaders = list(loaders)

        # cumulative max labels: entry j - 1 is the offset to add to loader j
        self.max_labels = [find_max_label(loader) for loader in loaders]
        for j in range(1, len(self.max_labels)):
            self.max_labels[j] += self.max_labels[j - 1]

        self.total_sample_length = sum(
            num_samples_in_loader(loader) for loader in loaders)

    def __len__(self):
        return self.total_sample_length

    def __iter__(self):
        # still incomplete: samples a single batch from a random loader and
        # ignores worker logic, loader exhaustion / StopIteration, etc.
        loader_idx = np.random.randint(len(self.loaders))
        samples, labels = next(iter(self.loaders[loader_idx]))
        if loader_idx > 0:
            labels = labels + self.max_labels[loader_idx - 1]

        yield samples, labels

But this seems nightmarish, since I would now need to work around the worker logic, make sure a StopIteration is raised at the right time, etc.

Any suggestions would be helpful.

Do you know the `Dataset` beforehand, i.e. are you creating it yourself?
If so, you could store the max label as a member variable and just read it back from the loader:

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
        self.max_label = torch.max(targets)  # store the max label once

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return len(self.data)


n_classes = 10
data = torch.randn(100, 2)
targets = torch.LongTensor(100).random_(n_classes)

dataset = MyDataset(data, targets)
loader = DataLoader(dataset)
print(loader.dataset.max_label)

Thanks for the response! In my case the dataset type is generally selected via argparse, so it isn't fixed ahead of time.

The problem isn't quite finding the max, but rather incrementing all of the labels by an offset. Most of the datasets I'm using are already in http://pytorch.org/docs/master/torchvision/datasets.html , however as far as I can tell there is no consistent member variable for the labels (e.g. MNIST exposes train_labels while SVHN exposes labels).

So I can't simply do something like the following:

# receive list of datasets from argparse
# example: we use two below
mnist = MNIST('./mnist')
svhn = SVHN('./svhn')
# some other datasets here
datasets = [mnist, svhn, ...]

# find the max label of each dataset and take the cumulative sum
dataset_label_max = [torch.max(dataset.train_labels) for dataset in datasets]
for i in range(1, len(dataset_label_max)):
    dataset_label_max[i] += dataset_label_max[i - 1]

# offset each dataset's labels by the cumulative max of the previous ones
for j in range(1, len(dataset_label_max)):
    datasets[j].train_labels += dataset_label_max[j - 1]

Ok, I understand. A workaround would be to try several possible attributes and print a warning (or raise an error) if no valid attribute is found:

def get_label_max(dataset):
    max_label = 0
    if hasattr(dataset, 'train_labels'):
        max_label = torch.max(dataset.train_labels)
    elif hasattr(dataset, 'labels'):
        max_label = torch.max(dataset.labels)
    elif hasattr(dataset, 'target'):
        max_label = torch.max(dataset.target)
    else:
        print('No labels found!')
    return max_label

get_label_max(train_dataset)

Good call. This seems like the best solution for now. Will submit a ticket to torchvision to try to get this standardized though. Thanks!
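
For anyone who lands on this thread later: the rough direction I'm going with is to compute the offsets via get_label_max above and apply them in a thin wrapper dataset, so the operator+ / ConcatDataset machinery keeps working. LabelOffsetDataset below is just a placeholder name I made up, not an existing torchvision class:

from torch.utils.data import Dataset
from torchvision.datasets import MNIST, SVHN


class LabelOffsetDataset(Dataset):
    """Thin wrapper that adds a fixed offset to every label of a dataset."""

    def __init__(self, dataset, offset):
        self.dataset = dataset
        self.offset = offset

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        return sample, label + self.offset

    def __len__(self):
        return len(self.dataset)


datasets = [MNIST('./mnist'), SVHN('./svhn')]

# shift each dataset by the cumulative max label of the previous ones
# (depending on how the labels are indexed you may want max + 1 here)
offset, shifted = 0, []
for ds in datasets:
    shifted.append(LabelOffsetDataset(ds, offset))
    offset += int(get_label_max(ds))

# operator+ on datasets concatenates them into a single ConcatDataset
merged = shifted[0] + shifted[1]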