How to efficiently load COCO dataset to dataloader

Hi, I have a problem loading COCO data into a DataLoader. When I do it, my RAM usage hits 100% (500 GB!).
I load my dataset like this:

import torch
import torchvision
from torch.utils.data import DataLoader, Dataset


class LoadDataset(Dataset):
    def __init__(self):
        self.images = []
        self.targets = []
        img_path, ann_path = (
            "path_to_images",
            "path_to_annotations",
        )
        coco_ds = torchvision.datasets.CocoDetection(img_path, ann_path)
        for i in range(0, len(coco_ds)):
            img, ann = coco_ds[i]
            for a in ann:
                width = a["bbox"][2]
                height = a["bbox"][3]
                image_size = width * height
                if image_size > 10000:  # I want only high quality images
                    self.targets.append(a)
                    self.images.append(img)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        target = self.targets[idx]

        target["some_parameter"] = torch.as_tensor(some_value)  # I added another 'head' to RCNN model

        return (
            img,
            target,
        )

Then I pass it to a DataLoader like this:

train_loader = DataLoader(LoadDataset(), batch_size=24, shuffle=True, num_workers=0)

The problem is your __init__ method: it loads and decodes every image in the dataset into memory. Don't do that.

Move the image loading logic to __getitem__, because that is the method meant for loading: the DataLoader calls it for one sample at a time. __init__ should only create and store the file paths and labels.
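A minimal sketch of that split, assuming a hypothetical layout of one file per sample plus an integer label. To keep it runnable with only the standard library it returns the raw file bytes; in a real dataset you would call Image.open(path).convert("RGB") and your transforms at the marked line, and subclass torch.utils.data.Dataset. LazyFileDataset is an illustrative name, not a library class.

```python
import os
from typing import List, Tuple


class LazyFileDataset:
    """Map-style dataset sketch: __init__ stores paths, __getitem__ loads."""

    def __init__(self, root: str, file_names: List[str], labels: List[int]) -> None:
        if len(file_names) != len(labels):
            raise ValueError("file_names and labels must have the same length")
        # Only cheap metadata is kept in memory: paths and labels.
        self.paths = [os.path.join(root, name) for name in file_names]
        self.labels = list(labels)

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Tuple[bytes, int]:
        # The file is read only when this sample is requested, so memory
        # use does not grow with the size of the dataset.
        # (Hypothetical stand-in for Image.open(path).convert("RGB").)
        with open(self.paths[idx], "rb") as f:
            data = f.read()
        return data, self.labels[idx]
```

Nothing is read from disk when the dataset is constructed; each access opens only the file for the requested index.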

Also, can't you use CocoDetection directly? You can just instantiate the dataset and pass it to a DataLoader.

import os

import torch.utils.data as data
from PIL import Image


class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version, e.g. ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']

        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
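One practical detail when passing CocoDetection straight to a DataLoader: each image has a different number of annotations, so the default collate_fn fails when it tries to batch the per-image annotation lists element by element. A common workaround, the same idea as the collate_fn in torchvision's detection reference scripts, is to keep images and targets as tuples. detection_collate is an illustrative name; the DataLoader lines are comments because they need torch, torchvision, and the COCO files.

```python
from typing import Any, List, Tuple


def detection_collate(batch: List[Tuple[Any, Any]]) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Group (image, target) pairs into (images, targets) without stacking."""
    images, targets = zip(*batch)
    return images, targets


# Usage (requires torch/torchvision and the dataset on disk):
# from torch.utils.data import DataLoader
# from torchvision.datasets import CocoDetection
# from torchvision.transforms import ToTensor
# loader = DataLoader(CocoDetection(root, annFile, transform=ToTensor()),
#                     batch_size=24, shuffle=True, num_workers=4,
#                     collate_fn=detection_collate)
```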

Thanks for the reply!
I can't use CocoDetection directly because I only want images that meet my criteria. I also have to modify the targets by adding an extra head to them.

Actually, where should I build the images and targets lists? I assumed in __init__, but from what you say I can't fill them there, because I shouldn't load the elements there (and to put elements into the lists, I have to load them first).
And isn't __getitem__ called on every iteration? If I have 4k iterations per epoch and 100 epochs, the data will be loaded 400k times, and each time it would take as much memory as loading everything up front. Am I right?

The DataLoader's role is to produce batches from the dataset, and that's all: it calls the dataset's __getitem__ for each sample index (which applies any transforms you passed to the dataset) and collates the results, optionally in parallel CPU worker processes when num_workers > 0. Iterating over it yields about number_of_samples / batch_size batches (rounded up, or down with drop_last=True). How many times you loop over it depends on your training logic, but one full pass over all samples is called an epoch, so with n epochs each sample is loaded n times, one batch at a time. Because only the current batch (plus any batches prefetched by workers) is in memory, RAM use stays roughly constant instead of growing with the dataset size.
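To make the counting concrete, here is a rough pure-Python model of shuffled batching. It is a simplification, assuming the default sampler behavior and ignoring worker processes and collation; num_batches, iterate_batches, and CountingDataset are hypothetical helpers, not PyTorch APIs. It shows that __getitem__ runs exactly once per sample per epoch and that batches are produced one at a time.

```python
import math
import random
from typing import Any, Iterator, List


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch, matching DataLoader's len() for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def iterate_batches(dataset, batch_size: int, shuffle: bool = True,
                    drop_last: bool = False, seed: int = 0) -> Iterator[List[Any]]:
    """Yield one batch at a time, loading each sample via dataset[idx]."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break
        # Only this batch's samples are loaded; earlier batches can be freed.
        yield [dataset[i] for i in chunk]


class CountingDataset:
    """Toy dataset that records how often __getitem__ is called."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.calls = 0

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        self.calls += 1
        return idx
```

For example, 100 samples with batch_size=24 give 5 batches per epoch (4 with drop_last=True), and 3 epochs call __getitem__ 300 times in total.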

This is an awesome tutorial on Custom Datasets:

To give you some direction, here is a sketch that subclasses CocoDetection:

from torchvision.datasets import CocoDetection


class CustomDataset(CocoDetection):
    def __init__(self,
                 root,
                 annFile,
                 transform=None,
                 target_transform=None) -> None:
        super().__init__(root, annFile, transform, target_transform)
        # self.ids holds COCO image ids (not file names). Keep only the ids
        # whose annotations satisfy your criterion. This reads only the
        # annotation JSON already parsed by pycocotools; no images are loaded.
        self.ids = [
            img_id
            for img_id in self.ids
            if any(a["bbox"][2] * a["bbox"][3] > 10000
                   for a in self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)))
        ]
        # If this filtering is slow, run it once, save the ids to a text file,
        # and load that file here instead. The length is automatically
        # correct because CocoDetection.__len__ returns len(self.ids).

    def __getitem__(self, index: int):
        img, target = super().__getitem__(index)
        # target is the list of annotation dicts for this image; modify it
        # here (e.g. add the field for your extra head)
        return img, target
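The filtering and the "save it to a text file" step can both run on the annotation JSON alone, without decoding any image. A minimal standard-library sketch, assuming a standard COCO-format file (top-level "images" and "annotations" lists, each annotation with "image_id", "category_id", and "bbox" as [x, y, width, height]). The function names, the one-id-per-line cache format, and the "some_parameter" key for the extra head are hypothetical. build_target converts one image's annotation list into the dict layout torchvision's detection models expect (boxes as [x1, y1, x2, y2] plus labels); wrap the lists with torch.as_tensor before feeding the model.

```python
import json
from typing import Any, Dict, List


def filter_image_ids(ann_file: str, min_area: float = 10000) -> List[int]:
    """Ids of images with at least one bbox whose w * h exceeds min_area."""
    with open(ann_file) as f:
        coco = json.load(f)
    keep = set()
    for ann in coco["annotations"]:
        _, _, w, h = ann["bbox"]
        if w * h > min_area:
            keep.add(ann["image_id"])
    # Preserve the order of the "images" list for reproducibility.
    return [img["id"] for img in coco["images"] if img["id"] in keep]


def save_ids(ids: List[int], path: str) -> None:
    """Write one image id per line (hypothetical cache format)."""
    with open(path, "w") as f:
        f.write("\n".join(str(i) for i in ids))


def load_ids(path: str) -> List[int]:
    """Read the ids written by save_ids."""
    with open(path) as f:
        return [int(line) for line in f if line.strip()]


def build_target(anns: List[Dict[str, Any]], extra_value: Any) -> Dict[str, Any]:
    """COCO [x, y, w, h] boxes -> [x1, y1, x2, y2], plus an extra field."""
    boxes = [[x, y, x + w, y + h] for x, y, w, h in (a["bbox"] for a in anns)]
    labels = [a["category_id"] for a in anns]
    # Hypothetical key for the extra RCNN head; convert to tensors in practice.
    return {"boxes": boxes, "labels": labels, "some_parameter": extra_value}
```

In CustomDataset you could then set self.ids = load_ids("filtered_ids.txt") after super().__init__(...) (the file name is illustrative) and return build_target(target, some_value) from __getitem__.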

Hope this helps!