Implement Nested Dataset

Hello, I am working on a few-shot learning framework.
In few-shot learning, each batch should contain K×N (image, label) pairs, where K is the number of classes and N is the number of shots (images per class).
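To make the K×N layout concrete, here is a minimal sketch that composes one episode from a plain dict mapping each class name to its items. The names items_by_class and sample_episode are hypothetical. Relabeling the chosen classes as 0..K-1 inside the episode is also an assumption (a common few-shot convention), not something stated in this post.

```python
import random
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def sample_episode(
    items_by_class: Dict[str, Sequence[T]],
    k: int,
    n: int,
    rng: random.Random,
) -> List[Tuple[T, int]]:
    """Return K*N (item, label) pairs: N items from each of K randomly chosen classes."""
    # Pick K distinct classes; labels are episode-local indices 0..K-1 (assumption).
    classes = rng.sample(sorted(items_by_class), k)
    episode: List[Tuple[T, int]] = []
    for label, cls in enumerate(classes):
        # Pick N distinct items from this class (requires at least N items).
        for item in rng.sample(list(items_by_class[cls]), n):
            episode.append((item, label))
    return episode


# Hypothetical toy data: three classes with three image paths each.
items_by_class = {
    "cat": ["cat0.jpg", "cat1.jpg", "cat2.jpg"],
    "dog": ["dog0.jpg", "dog1.jpg", "dog2.jpg"],
    "owl": ["owl0.jpg", "owl1.jpg", "owl2.jpg"],
}
episode = sample_episode(items_by_class, k=2, n=3, rng=random.Random(0))
print(len(episode))  # 6 == K * N
```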

To build such batches, I want to design the dataset and dataloader as below.
In each iteration, the dataloader iterates over classes, and the dataset for each class iterates over its items.

class FSLDataset(Dataset): each item is a ClassDataset that holds the data of a single class.
class ClassDataset(Dataset): each item is an (image, label) pair.

However, this does not work with the DataLoader, because its default collate function only supports tensors, numpy arrays, numbers, dicts or lists:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class '__main__.ClassDataset'>
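The error is raised by the DataLoader's default collate function, which tries to merge whatever the dataset returns into tensors. One way to keep the nested design is to pass a custom collate_fn that receives the K per-class datasets of a batch and draws N items from each. The sketch below is plain Python so it runs without PyTorch. ListClassDataset and EpisodeCollate are hypothetical names, and each class dataset is assumed to return (image, label) pairs.

```python
import random
from typing import List, Sequence, Tuple


class ListClassDataset:
    """Hypothetical stand-in for ClassDataset: (image, label) pairs of one class."""

    def __init__(self, pairs: Sequence[Tuple[str, int]]):
        self.pairs = list(pairs)

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, index: int) -> Tuple[str, int]:
        return self.pairs[index]


class EpisodeCollate:
    """collate_fn that turns a batch of K class datasets into K*N (image, label) pairs."""

    def __init__(self, n_shots: int, seed: int = 0):
        self.n_shots = n_shots
        self.rng = random.Random(seed)

    def __call__(self, class_datasets: List[ListClassDataset]) -> List[Tuple[str, int]]:
        episode: List[Tuple[str, int]] = []
        for ds in class_datasets:
            # Draw N distinct positions from this class (requires len(ds) >= N).
            for i in self.rng.sample(range(len(ds)), self.n_shots):
                episode.append(ds[i])
        return episode


# Three hypothetical classes with four images each.
classes = [ListClassDataset([(f"c{c}_img{i}", c) for i in range(4)]) for c in range(3)]
collate = EpisodeCollate(n_shots=2)
# This is the call a DataLoader with batch_size=3 (K) would make on its batch.
episode = collate(classes)
print(len(episode))  # 6
```

In a real pipeline this would be wired up roughly as DataLoader(fsl_dataset, batch_size=K, shuffle=True, collate_fn=EpisodeCollate(n_shots=N)), where fsl_dataset returns one ClassDataset per index. Converting the images to a stacked tensor is left out. With num_workers > 0, each worker process gets its own copy of the same seeded generator, so you may want to reseed it per worker.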

Do you have any ideas about this design?

I found my own solution:

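import torch


class NestedMetaDataset(torch.utils.data.Dataset):
    """Interleaves several per-class datasets round-robin: index i -> class i % K."""

    def __init__(self, load_dataset_class, classes, batch_size, **kwargs):
        # batch_size is not used here; batching is done by the DataLoader.
        self.datasets = [ClassDataset(load_dataset_class, cls, **kwargs) for cls in classes]
        self.class_num = len(classes)
        # Total length is the sum of all per-class lengths.
        self.len = sum(len(ds) for ds in self.datasets)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Pick the class round-robin, then the (index // K)-th item of that class.
        # Indexing directly (instead of calling next() on a stored iterator) keeps
        # __getitem__ stateless, so a second epoch or extra workers cannot exhaust it.
        # The modulo wraps around classes that have fewer items than the others.
        class_id = index % self.class_num
        dataset = self.datasets[class_id]
        item_idx = (index // self.class_num) % len(dataset)
        return dataset[item_idx]


class ClassDataset(torch.utils.data.Dataset):
    """Holds all (image, label) items of a single class."""

    def __init__(self, load_dataset_class, class_name, **kwargs):
        # load_dataset_class is the user-supplied loader for one class's items.
        self.items = load_dataset_class(class_name=class_name, **kwargs)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]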
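A common alternative is to keep the dataset flat and move the episode logic into a batch sampler that yields K×N indices per batch. The sketch below assumes the class label of every flat index is known up front (the labels list is hypothetical input). EpisodeBatchSampler is a hypothetical name, written as a plain iterable so it runs without PyTorch.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence


class EpisodeBatchSampler:
    """Yields lists of K*N dataset indices: N indices from each of K random classes."""

    def __init__(self, labels: Sequence[int], k: int, n: int, num_episodes: int, seed: int = 0):
        # Group flat dataset indices by their class label.
        self.indices_by_class: Dict[int, List[int]] = defaultdict(list)
        for idx, label in enumerate(labels):
            self.indices_by_class[label].append(idx)
        self.classes = sorted(self.indices_by_class)
        self.k, self.n = k, n
        self.num_episodes = num_episodes
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return self.num_episodes

    def __iter__(self) -> Iterator[List[int]]:
        for _ in range(self.num_episodes):
            batch: List[int] = []
            # Choose K distinct classes, then N distinct indices within each.
            for cls in self.rng.sample(self.classes, self.k):
                batch.extend(self.rng.sample(self.indices_by_class[cls], self.n))
            yield batch


labels = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]  # class label of each flat index
sampler = EpisodeBatchSampler(labels, k=2, n=2, num_episodes=5)
for batch in sampler:
    print(batch)  # 4 indices: 2 from each of 2 classes
```

It would be passed as DataLoader(flat_dataset, batch_sampler=sampler). batch_sampler accepts any iterable of index lists and is mutually exclusive with batch_size, shuffle, sampler and drop_last. Because the sampler only returns indices, default_collate receives ordinary (image, label) pairs and the TypeError above does not occur. Each new iteration over the sampler also draws fresh episodes, so nothing has to be reset between epochs.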