Train/validate on datasets of different size


(Andreas Leurs) #1

I have a training set and a validation set with different size (size trainset > size val set). I want to go through the whole training set and at the same time through the val set. When it reaches the last element of the val set it should just start at the first element of the val set and repeat that until the last element of the train set is reached.

When I try this

for iteration, (batch, batch2) in enumerate(zip(training_data_loader, val_data_loader), 0):

it only goes until the last element of the smaller data loader is reached and then ends.

So I think it would be best to do something like this:

for iteration, batch in enumerate(training_data_loader, 0):
    input = Variable(batch[0])
    target = next(iter(val_data_loader[iteration]))

My getitem function looks like that:

def __getitem__(self, index):
    """Return one ``(input, target)`` image pair for *index*.

    The index is wrapped with modulo so out-of-range indices cycle back
    to the start of the dataset — the asker's goal of reusing a shorter
    dataset while iterating a longer one.
    """
    # BUG FIX: the original guard `if index == len(self)` only remapped the
    # single value len(self); any other out-of-range index still raised
    # IndexError. Applying the modulo unconditionally makes every index wrap
    # and is a no-op for indices already in range.
    index = index % len(self)
    input = load_img(self.image_filenames[index])
    target = load_img(self.image_filenames_target[index])

    if self.input_transform:
        # Resize before transforming; 224x224 appears to be the expected
        # network input size — TODO confirm against the model definition.
        input = input.resize((224, 224))
        input = self.input_transform(input)
    if self.target_transform:
        target = target.resize((224, 224))
        target = self.target_transform(target)

    return input, target

But I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-28-6756247aaed3> in <module>()
 82     input = Variable(batch[0])
 83     print('input.size():', input.size())
---> 84     target = next(iter(val_data_loader[iteration]))
 85     print('target.size():', target.size())
 86     print('iteration', iteration, 'done')

TypeError: 'DataLoader' object does not support indexing

Do you have an idea how I can get it right?


(Andreas Leurs) #2

I solved the issue by also loading the val data set in the train dataset:

class DatasetFromFolder(data.Dataset):
    """Paired image dataset that also cycles through a (smaller) validation set.

    Each item yields ``(input, target, input_val, target_val)``.  The
    validation pair is selected with a modulo index so the shorter
    validation file lists wrap around while iterating the full training
    set (``__len__`` is the training-set length).

    NOTE(review): the original paste had the methods at column 0, outside
    the class body, which is not valid — they are re-indented into the
    class here; the logic is unchanged.
    """

    def __init__(self, image_dir_input, image_dir_target,
                 image_dir_input_val, image_dir_target_val,
                 input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        # Sorted file lists so inputs and targets pair up by filename order —
        # assumes matching names in the input and target directories; verify.
        self.image_filenames = [join(image_dir_input, x)
                                for x in listdir(image_dir_input) if is_image_file(x)]
        self.image_filenames.sort()
        self.image_filenames_target = [join(image_dir_target, x)
                                       for x in listdir(image_dir_target) if is_normal_file(x)]
        self.image_filenames_target.sort()

        self.image_filenames_val = [join(image_dir_input_val, x)
                                    for x in listdir(image_dir_input_val) if is_image_file(x)]
        self.image_filenames_val.sort()
        self.image_filenames_target_val = [join(image_dir_target_val, x)
                                           for x in listdir(image_dir_target_val) if is_normal_file(x)]
        self.image_filenames_target_val.sort()

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(input, target, input_val, target_val)`` for *index*.

        The validation lists are indexed modulo their own length so they
        repeat while the (longer) training lists are traversed once.
        """
        input = load_img(self.image_filenames[index])
        target = load_img(self.image_filenames_target[index])
        input_val = load_img(self.image_filenames_val[index % len(self.image_filenames_val)])
        target_val = load_img(self.image_filenames_target_val[index % len(self.image_filenames_target_val)])

        if self.input_transform:
            # 224x224 appears to be the expected network input size —
            # TODO confirm against the model definition.
            input = input.resize((224, 224))
            input = self.input_transform(input)
            input_val = input_val.resize((224, 224))
            input_val = self.input_transform(input_val)
        if self.target_transform:
            target = target.resize((224, 224))
            target = self.target_transform(target)
            target_val = target_val.resize((224, 224))
            target_val = self.target_transform(target_val)
        return input, target, input_val, target_val

    def __len__(self):
        # Length of the training set; validation lists wrap via modulo.
        return len(self.image_filenames)

And in the getitem function I get the right element with a modulo operation.