Dataloader for sequence of images that are not overlapping

Hi Everyone,

I’m doing frame generation. For each image in the Dataset/train/ folder (e.g. 1.png) I generated a sequence of 100 frames and saved them all into a single folder Dataset/frames/train/ as (1_1.png … 1_100.png). Here is an example of my folder structure:

Dataset:
      train:
           1.png
           2.png
           3.png
           .
           .
           N.png
     frames:
           train:
               1_1.png
               1_2.png
               .
               .
                1_100.png

               2_1.png
               2_2.png
               .
               .
               N_100.png

I have created my custom dataloader where I stack the generated frames as channels to form a sequence, but my problem is that I don’t want frames from image 2 to overlap with frames from image 1 when creating a sequence. How can I ensure that frames from different images don’t overlap?

Here is my custom dataloader:


class LevelSetDataset(Dataset):
    """
    Dataset pairing a source image with a window of its generated frames.

    Each source image ``k.png`` in ``input_image_path`` has a sequence of
    ``frames_per_image`` generated frames ``k_1.png .. k_F.png`` stored in
    ``target_image_path``.  A sample stacks the grayscale input image and
    ``num_input_steps`` consecutive binarized frames as channels:

        X: [H, W, C = num_input_steps + 1]   (channel 0 = input image)
        Y: binarized frame ``num_future_steps`` after the input window

    The flat sample index is mapped to (source image, window start), so a
    window NEVER crosses an image boundary: frames of image 2 are never
    mixed into a sequence built from image 1.
    """
    def __init__(self, input_image_path:str,
                target_image_path:str,
                threshold:float=0.5,
                num_input_steps:int=3,
                num_future_steps:int=1,
                image_dimension:int=32,
                data_transformations=None,
                istraining_mode:bool=True,
                frames_per_image:int=100
                ):

        self.input_image_path    = input_image_path
        self.target_image_path   = target_image_path
        self.threshold           = threshold
        self.num_input_steps     = num_input_steps
        self.num_future_steps    = num_future_steps
        self.image_dimension     = image_dimension
        self.data_transformations= data_transformations
        self.istraining_mode     = istraining_mode
        self.frames_per_image    = frames_per_image

        # Input filenames sorted numerically (1.png, 2.png, ..., N.png).
        self.input_image_fp = sorted(
            glob(os.path.join(self.input_image_path, "*")),
            key=lambda x: int(os.path.basename(x).split('.')[0]))

        # Frame filenames sorted by (image id, frame id):
        # 1_1.png, 1_2.png, ..., 1_F.png, 2_1.png, ...
        self.target_image_fp = sorted(
            glob(os.path.join(self.target_image_path, "*")),
            key=lambda x: (int(os.path.basename(x).split('_')[0]),
                           int(os.path.basename(x).split('_')[1].split('.')[0])))

        # Default transforms.  BUG FIX: the original eval branch used
        # `==` (a no-op comparison) instead of `=`, leaving
        # self.transforms as None outside training mode.
        if self.data_transformations is None:
            resize = torchvision.transforms.Resize(
                size=(self.image_dimension, self.image_dimension),
                interpolation=Image.BILINEAR)
            if self.istraining_mode:
                self.data_transformations = torchvision.transforms.Compose([
                    resize,
                    torchvision.transforms.RandomHorizontalFlip(p=0.5),
                    torchvision.transforms.RandomVerticalFlip(p=0.5),
                    torchvision.transforms.ToTensor()])
            else:
                self.data_transformations = torchvision.transforms.Compose([
                    resize,
                    torchvision.transforms.ToTensor()])

        self.transforms = self.data_transformations

    def _create_binary_mask(self, x):
        """Threshold a tensor in place into a {0, 1} mask."""
        x[x >= self.threshold] = 1
        x[x < self.threshold] = 0
        return x

    def _stat_norm(self, x):
        """Deterministic resize + ToTensor (no augmentation)."""
        norm = torchvision.transforms.Compose([
            torchvision.transforms.Resize(
                size=(self.image_dimension, self.image_dimension),
                interpolation=Image.BILINEAR),
            torchvision.transforms.ToTensor()])
        return norm(x)

    def _windows_per_image(self):
        """Valid window starts inside ONE image's frame sequence.

        A window needs num_input_steps input frames plus a target frame
        num_future_steps ahead, all from the same image.
        """
        return self.frames_per_image - (self.num_input_steps + self.num_future_steps) + 1

    def __len__(self):
        # Windows never span two source images, so the total is the
        # per-image window count times the number of source images.
        num_images = len(self.target_image_fp) // self.frames_per_image
        return num_images * self._windows_per_image()

    def __getitem__(self, index):
        # Map the flat index to (source image, window start).  All frame
        # indices below stay inside [base, base + frames_per_image), so
        # sequences from different images cannot overlap.
        windows = self._windows_per_image()
        img_idx = index // windows
        start   = index % windows
        base    = img_idx * self.frames_per_image  # first frame of this image

        X = torch.zeros((self.image_dimension, self.image_dimension,
                         self.num_input_steps + 1))

        # Channel 0: the grayscale source image itself.
        input_img  = Image.open(self.input_image_fp[img_idx]).convert('L')
        X[:, :, 0] = self.transforms(input_img)

        # Channels 1..num_input_steps: consecutive binarized frames.
        # (The original loop used np.arange(index, num_input_steps), which
        # is empty for index >= num_input_steps, and clobbered channel 0.)
        for step in range(self.num_input_steps):
            frame = Image.open(self.target_image_fp[base + start + step])
            X[:, :, step + 1] = self._create_binary_mask(self.transforms(frame))

        # Target: the frame num_future_steps after the last input frame,
        # still within the same source image's sequence.
        target_idx   = base + start + self.num_input_steps + self.num_future_steps - 1
        target_image = Image.open(self.target_image_fp[target_idx])
        Y            = self._create_binary_mask(self.transforms(target_image))
        image_name   = os.path.basename(self.target_image_fp[target_idx])

        return X, Y, image_name
    

@ptrblck may you please assist?