Feed the frame predicted by the model at time t back as the model's input at time t+1 using a PyTorch DataLoader

I'm doing video frame prediction using an LSTM. Below is my dataset class. During testing, I want the `_get_auxillary_label` function to take the output of the model at time step t (i.e. the predicted frame) and use it as the input for time step t+1.

Is it possible to do this, or is there a better way?

Here is my data loader code:

 class TimeSeriesImagesSegmentation(Dataset):
    """
    Sequence dataset of (original image, segmentation) pairs for an LSTM.

    Each sequence k consists of one original image ``<k><ext0>`` and the
    segmentation frames ``<k>_<frame><ext1>`` for frame in
    ``1..total_segemented_images``. A sample pairs the grayscale original
    image and the segmentation at frame ``t`` (model input) with the
    segmentation at frame ``t + lag`` (target). Samples never cross from
    one sequence into the next.
    """
    def __init__(self,  original_image_path:str,
                        segmentation_image_path:str,
                        split:str='train',
                        image_dimension:int=32,
                        total_original_images:int=200,
                        total_segemented_images:int=100,
                        threshold:float=0.5,
                        lag:int=1,
                        extentions=('.jpg', '.jpg')):
        """
        Data loader for training on Chan-Vese generated images using full images.

        Args:
            original_image_path: directory holding the original images.
            segmentation_image_path: directory holding the segmentation frames.
            split: 'train' enables sequence shuffling, random flips and
                ground-truth input labels; any other value feeds back the
                label set via ``_get_auxillary_label`` instead.
            image_dimension: side length every image is resized to.
            total_original_images: number of sequences to use; None uses as
                many as there are entries in ``original_image_path``. Clamped
                to the number of sequences actually found.
            total_segemented_images: number of segmentation frames per sequence.
            threshold: binarisation threshold applied to the label tensors.
            lag: distance in frames between input label and target label.
            extentions: two-item sequence (original ext, segmentation ext).
        """
        self.original_image_path     = original_image_path
        self.segmentation_image_path = segmentation_image_path

        if total_original_images is None:
            # use all the images in the directory
            total_original_images = len(os.listdir(original_image_path))

        self.total_original_images   = total_original_images
        self.total_segemented_images = total_segemented_images
        self.image_dimension         = image_dimension
        self.split                   = split
        self.extentions              = extentions
        self.threshold               = threshold
        self.lag                     = lag
        # Prediction fed back as the input label outside training; None until
        # the caller provides one through _get_auxillary_label.
        self.auxillary_label         = None
        self.original_images, self.labels, self.names = self.get_file_names()

        # get_file_names can only return as many sequences as exist on disk;
        # keep the count in sync so __len__ never points past the file lists.
        if self.total_segemented_images > 0:
            self.total_original_images = len(self.original_images) // self.total_segemented_images

    def __len__(self):
        """Number of samples: ``total_segemented_images - lag`` per sequence (never negative)."""
        return max(0, self.total_original_images * (self.total_segemented_images - self.lag))

    def _frame_indices(self, index):
        """
        Map a dataset index to (input, target) positions in the flat file lists.

        Samples are laid out sequence by sequence, each contributing
        ``total_segemented_images - lag`` entries, so the target position
        ``input + lag`` always lies in the same sequence as the input.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of length {len(self)}")
        samples_per_sequence = self.total_segemented_images - self.lag
        sequence, frame = divmod(index, samples_per_sequence)
        input_index = sequence * self.total_segemented_images + frame
        return input_index, input_index + self.lag

    @staticmethod
    def _load_image(directory, file_name, mode=None):
        """Read ``directory/file_name`` fully into memory, close the file, and return it."""
        with Image.open(os.path.join(directory, file_name)) as img:
            # convert()/copy() force the pixel data to load before the file closes.
            return img.convert(mode) if mode else img.copy()

    def _transforms(self, image, input_label, target_label):
        """
        Resize, augment (train only), tensorize and binarise one sample.

        Returns:
            stacked_images: ``image`` concatenated along dim 0 with the input
                label (train) or with the fed-back prediction (other splits).
            target_label: binarised target label tensor.
            count_foreground_pixels: number of non-zero entries of the
                binarised input label divided by ``image_dimension ** 2``
                (a fraction, not a count).
        """
        resize        = torchvision.transforms.Resize(size=(self.image_dimension, self.image_dimension))
        image         = resize(image)
        input_label   = resize(input_label)
        target_label  = resize(target_label)

        # Random flips are augmentation and belong to training only; the
        # previous `if self.split:` was true for every non-empty split name.
        if self.split == 'train':
            # Horizontal flip
            if np.random.random() > 0.5:
                image = TF.hflip(image)
                input_label = TF.hflip(input_label)
                target_label = TF.hflip(target_label)

            #  Random vertical flipping
            if np.random.random() > 0.5:
                image = TF.vflip(image)
                input_label = TF.vflip(input_label)
                target_label = TF.vflip(target_label)

        # transform to tensor
        image        = TF.to_tensor(image)
        input_label  = TF.to_tensor(input_label)
        target_label = TF.to_tensor(target_label)

        # form image classes (i.e. background vs foreground)
        input_label[input_label>=self.threshold]   = 1
        input_label[input_label< self.threshold]   = 0
        target_label[target_label>=self.threshold] = 1
        target_label[target_label< self.threshold] = 0

        if self.split == 'train':
            # stack the images together (input + label)
            stacked_images = torch.cat((image, input_label), dim=0)
        else:
            # Condition on the model's own previous prediction. Before any
            # prediction exists (first step of a rollout) seed with the
            # ground-truth input label. The prediction is used as-is, so it
            # presumably must already be (C, image_dimension, image_dimension)
            # without a batch dimension -- TODO confirm against the model.
            previous_label = input_label if self.auxillary_label is None else self.auxillary_label
            stacked_images = torch.cat((image, previous_label), dim=0)

        # fraction of foreground pixels in the input label
        count_foreground_pixels =  (torch.nonzero(input_label).size(0) / (self.image_dimension*self.image_dimension))

        return stacked_images, target_label, count_foreground_pixels

    def _get_auxillary_label(self, auxillary_label):
        """
        Store the model's prediction as the input label for subsequent samples.

        Only used when ``split != 'train'``. Tensors are detached and moved to
        the CPU so no autograd graph is kept alive and ``torch.cat`` with the
        CPU image tensors built in ``_transforms`` works.

        NOTE: with ``num_workers > 0`` every DataLoader worker holds its own
        copy of the dataset, so a value set here from the main process is not
        seen by the workers; batching/prefetching also means later samples
        may be built before the prediction is stored. For an autoregressive
        rollout use ``num_workers=0`` and ``batch_size=1`` (and pass None here
        at the start of each sequence), or feed the prediction back inside the
        evaluation loop instead of through the dataset.
        """
        if isinstance(auxillary_label, torch.Tensor):
            auxillary_label = auxillary_label.detach().cpu()
        self.auxillary_label = auxillary_label

    def get_file_names(self):
        """
        Build the parallel file-name lists for every (sequence, frame) pair.

        Returns:
            original_images: ``'<k><ext0>'``, repeated once per frame of sequence k.
            labels: ``'<k>_<frame><ext1>'`` for frame in ``1..total_segemented_images``.
            names: ``original_image_path + '/' + original_images[i]``.
        Sequence order is shuffled when ``split == 'train'``.
        """
        # Assumes the original images are named 1..K where K is the number of
        # entries in the directory -- TODO confirm nothing else lives there.
        original_image_index = np.arange(1, len(os.listdir(self.original_image_path)) + 1)
        if self.split == 'train':
            np.random.shuffle(original_image_index)
        original_image_index = original_image_index[:self.total_original_images]

        original_images, labels, names = [], [], []
        for original_image in original_image_index:  # for each original image
            image_name_original = f"{original_image}{self.extentions[0]}"
            for segmentation_image in range(1, self.total_segemented_images + 1):  # for each segmentation frame
                original_images.append(image_name_original)
                labels.append(f"{original_image}_{segmentation_image}{self.extentions[1]}")
                names.append(self.original_image_path + '/' + image_name_original)
        return original_images, labels, names

    def __getitem__(self, index):
        """
        Return ``(stacked_images, target_label, image_name)`` for one sample.

        The input label is frame ``t`` of a sequence and the target is frame
        ``t + lag`` of the same sequence.
        """
        input_index, target_index = self._frame_indices(index)

        image        = self._load_image(self.original_image_path, self.original_images[input_index], mode='L')  # inputs
        input_label  = self._load_image(self.segmentation_image_path, self.labels[input_index])    # labels used as inputs
        target_label = self._load_image(self.segmentation_image_path, self.labels[target_index])   # labels used as target

        stacked_images, labels, _count_foreground_pixels = self._transforms(image, input_label, target_label)

        # original_images holds bare file names, i.e. the basename of names[i].
        image_name = self.original_images[input_index]
        return stacked_images, labels, image_name