Loading dataset to cuda

Hello everyone!

I was wondering: is it possible to store a list or a PIL image in GPU memory without converting it to a tensor?

My problem is that I keep my entire training set in RAM, which works; however, there is not enough RAM to store the validation data as well.

Previously, I stored all the data in RAM or on CUDA and applied the transforms in the constructor, so keeping the data on the CPU or on CUDA was fine. Now, however, I apply the transforms on the fly, so at construction time the images are still PIL Images and I can only keep them on the CPU.

I could open the images in `__getitem__` instead, but loading the entire dataset into RAM or onto CUDA up front saved me about 3.5 minutes per epoch.

The Dataset class is below. I would also appreciate any advice on improving it, as I'm fairly new to this.

Thank you!

class Data(Dataset):
    """Image/trimap dataset whose samples are decoded once and cached in RAM.

    Every entry of ``class_list`` is loaded in the constructor as a pair of PIL
    images (RGB image, trimap). Augmentation is applied on the fly in
    ``__getitem__``, so each epoch sees freshly augmented tensors.

    Args:
        class_list: Sequence of entries whose first element ``c[0]`` is the
            file stem shared by ``<images_path>/<stem>.jpg`` and
            ``<trimap_path>/<stem>.png``.
        images_path: Directory containing the ``.jpg`` images.
        trimap_path: Directory containing the ``.png`` trimaps.
        img_size: Side length of the square outputs.
        ratio_labelled: When ``unlabaled_flag`` is true, the fraction of
            entries (randomly sampled) whose trimaps are kept; all other
            entries get a mask filled with -1.
        device: Stored as ``self.device``; not used inside this class.
        unlabaled_flag: If true, samples are ``(weak_image, strong_image,
            trimap)`` with train-time augmentation; otherwise samples are
            ``(image, trimap)`` with validation preprocessing.
    """

    def __init__(self, class_list, images_path, trimap_path, img_size, ratio_labelled, device, unlabaled_flag=False):
        self.class_list = class_list
        self.images_path = images_path
        self.trimap_path = trimap_path
        self.unlabaled_flag = unlabaled_flag
        self.device = device
        self.ratio_labelled = ratio_labelled
        self.img_size = img_size

        if unlabaled_flag:
            num_labelled_classes = int(self.ratio_labelled * len(self.class_list))
            sampled = random.sample(self.class_list, num_labelled_classes)
            # Hash the sampled stems once: membership tests become O(1) here
            # and in every __getitem__, instead of scanning a list each time.
            self._labelled_names = {c[0] for c in sampled}
            self.labelled_class_names = [c[0] for c in class_list if c[0] in self._labelled_names]

        self.data = [self._load_pair(c[0]) for c in class_list]

    def _load_pair(self, stem):
        """Decode the RGB image and trimap for ``stem`` fully into memory."""
        image_file = os.path.join(self.images_path, stem + ".jpg")
        trimap_file = os.path.join(self.trimap_path, stem + ".png")
        # Image.open is lazy: without an explicit load the file stays open and
        # the pixels are only read on first access. Load inside ``with`` so the
        # data truly lives in RAM and no file descriptors are leaked.
        with Image.open(image_file) as img:
            rgb_image = img.convert("RGB")
        with Image.open(trimap_file) as tri:
            trimap_image = tri.copy()
        return [rgb_image, trimap_image]

    @staticmethod
    def augmentation(image, trimap, img_size, flag="train"):
        """Transform an (image, trimap) PIL pair into tensors.

        Args:
            image: RGB PIL image.
            trimap: Trimap PIL image, spatially aligned with ``image``.
            img_size: Side length of the square outputs.
            flag: ``"train"`` for random geometric + colour augmentation,
                ``"val"`` for a deterministic resize.

        Returns:
            ``"train"``: ``(weak_image, strong_image, trimap)``, i.e. two
            differently colour-jittered, normalized views sharing one geometric
            transform, plus the matching trimap tensor.
            ``"val"``: ``(image, trimap)``.

        Raises:
            ValueError: If ``flag`` is neither ``"train"`` nor ``"val"``.
        """
        if flag not in ("train", "val"):
            raise ValueError(f"flag must be 'train' or 'val', got {flag!r}")

        size = (img_size, img_size)
        # Trimap pixels are label ids, so they must never be blended by
        # bilinear interpolation; resample masks with NEAREST.
        nearest = transforms.InterpolationMode.NEAREST
        to_tensor = transforms.ToTensor()
        norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

        if flag == "val":
            val_image = transforms.Resize(size)(image)
            val_trimap = transforms.Resize(size, interpolation=nearest)(trimap)
            return norm(to_tensor(val_image)), to_tensor(val_trimap)

        weak_color_noise = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05)
        strong_color_noise = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)

        # Sample the crop once and apply it to both so image and mask stay aligned.
        i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=(0.8, 1.0), ratio=(3 / 4, 4 / 3))
        image = F.resized_crop(image, i, j, h, w, size)
        trimap = F.resized_crop(trimap, i, j, h, w, size, interpolation=nearest)

        if random.random() < 0.3:
            angle = random.randint(-5, 5)
            image = F.rotate(image, angle)
            trimap = F.rotate(trimap, angle)

        if random.random() < 0.5:
            image = F.hflip(image)
            trimap = F.hflip(trimap)

        weak_image = weak_color_noise(image)
        strong_image = strong_color_noise(image)
        if random.random() < 0.5:
            strong_image = transforms.GaussianBlur(3, sigma=(0.1, 2.0))(strong_image)

        return norm(to_tensor(weak_image)), norm(to_tensor(strong_image)), to_tensor(trimap)

    @staticmethod
    def mask_blend(trimap_image):
        """Binarize a ToTensor-scaled trimap in place and return it.

        Raw pixel ids are recovered as ``round(t * 255)``: id 1 becomes 1.0,
        ids 2 and 3 become 0.0, and every other value is left unchanged.
        """
        # Compare integer ids instead of exact floats (k / 255) so the mapping
        # does not hinge on float rounding of the division.
        pixel_ids = torch.round(trimap_image * 255)
        trimap_image[(pixel_ids == 2) | (pixel_ids == 3)] = 0.0
        trimap_image[pixel_ids == 1] = 1.0
        return trimap_image

    def __len__(self):
        return len(self.class_list)

    def __getitem__(self, idx):
        current_image, current_trimap = self.data[idx]

        if not self.unlabaled_flag:
            dataset_image, trimap_image = self.augmentation(current_image, current_trimap, self.img_size, "val")
            return dataset_image, self.mask_blend(trimap_image)

        dataset_image, dataset_image2, trimap_image = self.augmentation(current_image, current_trimap, self.img_size, "train")
        if self.class_list[idx][0] in self._labelled_names:
            trimap_image = self.mask_blend(trimap_image)
        else:
            # Unlabelled sample: fill with -1 but keep the float dtype of the
            # labelled masks so mixed batches share a single dtype.
            trimap_image = torch.full_like(trimap_image, -1.0)
        return dataset_image, dataset_image2, trimap_image

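No. A PIL `Image` keeps its pixel data in CPU memory managed by PIL itself (not in a numpy array), and neither PIL nor numpy can place data on the GPU. You would thus need to convert the data to a tensor first (e.g. via `torch.from_numpy(np.array(img))`) and could move it to the GPU afterwards. Note that your on-the-fly augmentations would then have to operate on tensors instead of PIL images.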

1 Like