Hello everyone!
I was wondering whether it is possible to store a list or a PIL image in GPU memory without converting it to a tensor.
My problem is that I keep my entire training set in RAM, which is fine; however, there is not enough RAM to hold the validation data as well.
Previously, I stored all the data in RAM or on CUDA and applied the transforms in the constructor, so keeping the data on the CPU or on CUDA was fine. Now, however, I apply the transforms on the fly, so in the constructor the images are still PIL Images, and I can only store them on the CPU.
I could open the images in `__getitem__` instead, but preloading the entire dataset into RAM (or onto CUDA) saved me about 3.5 minutes per epoch.
My Dataset class is below. I would also appreciate any advice on improving it, as I’m fairly new to this.
Thank you!
class Data(Dataset):
    """Segmentation dataset of RGB images and trimaps, preloaded into memory.

    Every image/trimap pair named in ``class_list`` is decoded once in the
    constructor and kept as a PIL image; augmentation and tensor conversion
    happen on the fly in ``__getitem__``. Decoded images are much larger than
    the compressed files on disk, so the whole split must fit in RAM.

    Args:
        class_list: Sequence of records; ``record[0]`` is the file stem shared
            by ``<images_path>/<stem>.jpg`` and ``<trimap_path>/<stem>.png``.
        images_path: Directory holding the ``.jpg`` images.
        trimap_path: Directory holding the ``.png`` trimaps.
        img_size: Side length of the square tensors produced.
        ratio_labelled: Fraction of records (sampled at random) whose masks are
            kept when ``unlabaled_flag`` is true; the others get a mask of -1.
        device: Stored as ``self.device``; not used inside this class.
        unlabaled_flag: If true, items are ``(weak_view, strong_view, mask)``
            built with the "train" augmentation; otherwise ``(image, mask)``
            built with the "val" preprocessing.
    """

    def __init__(self, class_list, images_path, trimap_path, img_size, ratio_labelled, device, unlabaled_flag=False):
        self.class_list = class_list
        self.images_path = images_path
        self.trimap_path = trimap_path
        self.unlabaled_flag = unlabaled_flag
        self.device = device
        self.ratio_labelled = ratio_labelled
        self.img_size = img_size
        if unlabaled_flag:
            num_labelled_classes = int(self.ratio_labelled * len(self.class_list))
            labelled_records = random.sample(self.class_list, num_labelled_classes)
            # A set of names gives O(1) membership tests here and in __getitem__,
            # instead of scanning a list for every record.
            self._labelled_name_set = {c[0] for c in labelled_records}
            # Public list (in class_list order) kept for callers that read it.
            self.labelled_class_names = [c[0] for c in class_list if c[0] in self._labelled_name_set]
        self.data = [self._load_pair(c[0]) for c in class_list]

    def _load_pair(self, stem):
        """Return ``[rgb_image, trimap]`` for ``stem`` with pixels loaded and files closed."""
        image_file = os.path.join(self.images_path, stem + ".jpg")
        trimap_file = os.path.join(self.trimap_path, stem + ".png")
        # Image.open is lazy and keeps the file open until the pixels are read;
        # preloading thousands of images that way can exhaust the open-file
        # limit. Materialise a new image inside `with` so each file is closed.
        with Image.open(image_file) as img:
            rgb_image = img.convert("RGB")
        with Image.open(trimap_file) as tri:
            trimap_image = tri.copy()
        return [rgb_image, trimap_image]

    @staticmethod
    def augmentation(image, trimap, img_size, flag="train"):
        """Apply paired preprocessing to an image and its trimap.

        Args:
            image: RGB PIL image.
            trimap: Trimap PIL image spatially aligned with ``image``.
            img_size: Output side length.
            flag: ``"train"`` for random crop/rotate/flip plus colour noise,
                ``"val"`` for a deterministic resize.

        Returns:
            ``(weak_image, strong_image, trimap)`` for ``"train"`` and
            ``(image, trimap)`` for ``"val"``. Images are normalised tensors;
            the trimap is the raw ``ToTensor`` output.

        Raises:
            ValueError: if ``flag`` is neither ``"train"`` nor ``"val"``.
        """
        to_tensor = transforms.ToTensor()
        norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
        # Trimaps hold discrete class ids: interpolating them (bilinear is the
        # default for resize/resized_crop) would blend ids 1/2/3 at region
        # borders into values that mask_blend cannot map.
        nearest = transforms.InterpolationMode.NEAREST

        if flag == "train":
            weak_color_noise = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05)
            strong_color_noise = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)
            i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=(0.8, 1.0), ratio=(3 / 4, 4 / 3))
            image = F.resized_crop(image, i, j, h, w, (img_size, img_size))
            trimap = F.resized_crop(trimap, i, j, h, w, (img_size, img_size), interpolation=nearest)
            if random.random() < 0.3:
                angle = random.randint(-5, 5)
                image = F.rotate(image, angle)
                trimap = F.rotate(trimap, angle)  # F.rotate defaults to nearest interpolation
            if random.random() < 0.5:
                image = F.hflip(image)
                trimap = F.hflip(trimap)
            weak_image = weak_color_noise(image)
            strong_image = strong_color_noise(image)
            if random.random() < 0.5:
                strong_image = transforms.GaussianBlur(3, sigma=(0.1, 2.0))(strong_image)
            return norm(to_tensor(weak_image)), norm(to_tensor(strong_image)), to_tensor(trimap)

        if flag == "val":
            image = transforms.Resize((img_size, img_size))(image)
            trimap = transforms.Resize((img_size, img_size), interpolation=nearest)(trimap)
            return norm(to_tensor(image)), to_tensor(trimap)

        raise ValueError(f"flag must be 'train' or 'val', got {flag!r}")

    @staticmethod
    def mask_blend(trimap_image):
        """Binarise a ToTensor-scaled trimap in place and return it.

        Id 1 (stored as 1/255) becomes 1.0; ids 2 and 3 become 0.0; any other
        value is left unchanged.
        """
        # Recover integer ids rather than testing floats for exact equality
        # with k/255, which breaks on any rounding difference.
        ids = torch.round(trimap_image * 255)
        trimap_image[(ids == 2) | (ids == 3)] = 0.0
        trimap_image[ids == 1] = 1.0
        return trimap_image

    def __len__(self):
        """Return the number of records in ``class_list``."""
        return len(self.class_list)

    def __getitem__(self, idx):
        """Return the preprocessed item at ``idx`` (layout depends on ``unlabaled_flag``)."""
        current_image, current_trimap = self.data[idx]
        if self.unlabaled_flag:
            dataset_image, dataset_image2, trimap_image = self.augmentation(
                current_image, current_trimap, self.img_size, "train")
            if self.class_list[idx][0] in self._labelled_name_set:
                trimap_image = self.mask_blend(trimap_image)
            else:
                # Same dtype and shape as a real mask, so a batch mixing
                # labelled and unlabelled items can be stacked by default_collate.
                trimap_image = torch.full_like(trimap_image, -1.0)
            return dataset_image, dataset_image2, trimap_image

        dataset_image, trimap_image = self.augmentation(
            current_image, current_trimap, self.img_size, "val")
        return dataset_image, self.mask_blend(trimap_image)