I’m doing video frame prediction using an LSTM. Below is my dataloader. During testing, I want the _get_auxillary_label function to take the output from the model (i.e. the predicted frame) at time step t and use it as the input for time step t+1.
Is it possible to do this? Or is there perhaps a better way of doing this? (I’ve sketched the test loop I have in mind after the code.)
Here is my dataloader code (imports included):

import os
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset

class TimeSeriesImagesSegmentation(Dataset):
"""
Returns sequence of images generated from video.
Trains on the x dataset H x W images
Used by LSTM
"""
def __init__(self, original_image_path:str,
segmentation_image_path:str,
split:str='train',
image_dimension:int=32,
total_original_images:int=200,
total_segemented_images:int=100,
threshold:float=0.5,
lag:int=1,
extentions:list=['.jpg','.jpg']):
"""
Data loader for training on Shan-Vese generated images
using full images.
"""
self.original_image_path = original_image_path
self.segmentation_image_path = segmentation_image_path
if total_original_images is None:
# use all the images in the directory
self.total_original_images = len(os.listdir(original_image_path))
else:
self.total_original_images = total_original_images
self.total_segemented_images = total_segemented_images
self.image_dimension = image_dimension
self.split = split
self.extentions = extentions
self.threshold = threshold
self.lag = lag
self.original_images, self.labels, self.names = self.get_file_names()
    def __len__(self):
        # each original image contributes (total_segemented_images - lag) usable pairs
        return self.total_original_images * (self.total_segemented_images - self.lag)
    def _transforms(self, image, input_label, target_label):
        resize = torchvision.transforms.Resize(size=(self.image_dimension, self.image_dimension))
        image = resize(image)
        input_label = resize(input_label)
        target_label = resize(target_label)
        if self.split == 'train':
            # random horizontal flipping
            if np.random.random() > 0.5:
                image = TF.hflip(image)
                input_label = TF.hflip(input_label)
                target_label = TF.hflip(target_label)
            # random vertical flipping
            if np.random.random() > 0.5:
                image = TF.vflip(image)
                input_label = TF.vflip(input_label)
                target_label = TF.vflip(target_label)
        # transform to tensor
        image = TF.to_tensor(image)
        input_label = TF.to_tensor(input_label)
        target_label = TF.to_tensor(target_label)
        # form image classes (i.e. background vs foreground)
        input_label[input_label >= self.threshold] = 1
        input_label[input_label < self.threshold] = 0
        target_label[target_label >= self.threshold] = 1
        target_label[target_label < self.threshold] = 0
        if self.split == 'train':
            # stack the input image and the ground-truth label along the channel dim
            stacked_images = torch.cat((image, input_label), dim=0)
        else:
            # stack the input image and the model's previous prediction; fall back
            # to the ground-truth label before the first prediction has been stored
            auxillary_label = getattr(self, 'auxillary_label', input_label)
            stacked_images = torch.cat((image, auxillary_label), dim=0)
        # fraction of foreground pixels in the input label
        count_foreground_pixels = torch.nonzero(input_label).size(0) / (self.image_dimension * self.image_dimension)
        return stacked_images, target_label, count_foreground_pixels
    def _get_auxillary_label(self, auxillary_label):
        """
        Store the model's prediction from time step t so that _transforms
        can stack it with the input image at time step t+1. Expects a
        (1, H, W) CPU tensor with H = W = image_dimension.
        """
        self.auxillary_label = auxillary_label
    def get_file_names(self):
        """
        Get the file names of the input and target images.
        """
        original_images, labels, names = [], [], []
        original_image_index = np.arange(1, len(os.listdir(self.original_image_path)) + 1)
        if self.split == 'train':
            np.random.shuffle(original_image_index)
        # slice in both splits so the lists stay consistent with __len__
        original_image_index = original_image_index[:self.total_original_images]
        for original_image in original_image_index:  # for each original image
            for segmentation_image in np.arange(1, self.total_segemented_images + 1):  # for each segmentation image
                image_name_original = str(original_image) + self.extentions[0]
                image_name_segementation = str(original_image) + '_' + str(segmentation_image) + self.extentions[1]
                # store the names of the input and label images
                original_images.append(image_name_original)
                labels.append(image_name_segementation)
                names.append(self.original_image_path + '/' + image_name_original)
        return original_images, labels, names
    def __getitem__(self, index):
        # map the flat index to (video, frame) so that inputs(t) and labels(t+lag)
        # always come from the same original image sequence
        frames_per_video = self.total_segemented_images - self.lag
        video_idx, frame_idx = divmod(index, frames_per_video)
        index = video_idx * self.total_segemented_images + frame_idx
        # shift the data to form a sequence: (inputs(t), labels(t+lag))
        image = Image.open(self.original_image_path + '/' + self.original_images[index]).convert('L')  # inputs
        input_label = Image.open(self.segmentation_image_path + '/' + self.labels[index])  # labels used as inputs
        target_label = Image.open(self.segmentation_image_path + '/' + self.labels[index + self.lag])  # labels used as targets
        # apply transformations (plus data augmentation when training)
        stacked_images, labels, count_foreground_pixels = self._transforms(image, input_label, target_label)
        # get the image name
        image_name = str(self.names[index].split("/")[-1])
        return stacked_images, labels, image_name
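
For context, here is roughly the test loop I have in mind. This is a minimal sketch, not working code: model and device are placeholders for my actual setup, and I'm assuming batch_size=1 with shuffle=False so frames arrive in order, and a model that outputs a (1, 1, H, W) probability map.

# minimal sketch of the autoregressive test loop (model and device are
# placeholders; assumes batch_size=1, shuffle=False, num_workers=0)
test_dataset = TimeSeriesImagesSegmentation(original_image_path, segmentation_image_path, split='test')
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

model.eval()
with torch.no_grad():
    for stacked_images, target_label, image_name in test_loader:
        prediction = model(stacked_images.to(device))   # predict the frame at t+1
        prediction = (prediction >= 0.5).float()        # binarise like the labels
        # feed the prediction back so the next __getitem__ stacks it
        # with the next input image instead of the ground-truth label
        test_dataset._get_auxillary_label(prediction.squeeze(0).cpu())

One thing I'm unsure about: with num_workers > 0 each worker holds its own copy of the dataset, so the auxillary_label I set in the main process would never reach __getitem__. Is keeping num_workers=0 the right way around this, or is there a cleaner pattern?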