Hi,
I am currently implementing a modified ResNet CNN model. The data is in HDF5 format. I read the entire dataset into memory and then train/test split it. The following is a simple example.
class mydataset(Dataset):
    """In-memory dataset backed by an HDF5 file.

    The whole file content (images and labels) is read into memory once at
    construction time; ``__getitem__`` then serves samples from RAM.

    Parameters
    ----------
    input_file : str
        Path to the HDF5 file containing the images and labels.
    transforms : callable, optional
        Transform applied to each image in ``__getitem__`` (e.g. a
        ``torchvision.transforms.Compose``). ``None`` disables transforms.
    """

    def __init__(self,
                 input_file,
                 transforms=None):
        super(mydataset, self).__init__()
        self.transforms = transforms
        # NOTE(review): the handle stays open for the dataset's lifetime;
        # close it when the dataset is no longer needed.
        self.file = h5py.File(input_file, "r")
        # Placeholders from the original post. The original image line read
        # `self.img = ....` (four dots) which is a SyntaxError; fixed here.
        self.label = ...  # TODO: read labels into memory, e.g. self.file["label"][:]
        self.img = ...    # TODO: read images into memory, e.g. self.file["img"][:]

    def __len__(self):
        """Return the number of samples (one label per image)."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at *idx*, applying transforms if set."""
        img = self.img[idx]
        label = self.label[idx]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label
# Then I do the train & validation split as follows
def pytorch_train_test_split(dataset,
                             batch_size=32,
                             test_size=0.3,
                             random_state=4,
                             shuffle=True,
                             test_batch_size=512):
    '''
    Split a dataset into train and test DataLoaders.

    Arguments:
    ----------
    dataset (torch.utils.data.Dataset) : a pytorch dataset
    batch_size (int) : batch size of the train loader
    test_size (float) : fraction of the dataset used for testing
    random_state (int) : seed used when shuffling the indices
    shuffle (bool) : if True, shuffle the indices before splitting
    test_batch_size (int) : batch size of the test loader (previously a
        hard-coded 512 that ignored configuration; kept as the default
        for backward compatibility)

    Return:
    -------
    train dataloader, test dataloader
    '''
    dataset_size = len(dataset)  # idiomatic len() instead of __len__()
    indices = list(range(dataset_size))
    # Number of samples that go to the test split.
    split = int(np.floor(test_size * dataset_size))
    if shuffle:
        # NOTE: seeding the *global* NumPy RNG keeps splits reproducible
        # across calls, but mutates global state; preserved here so
        # existing callers get the exact same split as before.
        np.random.seed(random_state)
        np.random.shuffle(indices)
    train_indices, test_indices = indices[split:], indices[:split]
    # Samplers restrict each loader to its own subset of indices.
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=train_sampler)
    test_loader = DataLoader(dataset, batch_size=test_batch_size,
                             sampler=test_sampler)
    return train_loader, test_loader
# main()
# Compose the augmentation pipeline. Bound to `train_transforms` rather than
# `transforms`, because the original assignment rebinds (shadows) the
# torchvision `transforms` module it is reading from.
train_transforms = transforms.Compose([transforms.ToPILImage(),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor()])
# Read the dataset (HDF5 file loaded into memory by mydataset).
my_data = mydataset("somefile.h5", train_transforms)
# Generate train and validation DataLoaders from the single dataset.
train_loader, valid_loader = pytorch_train_test_split(my_data,
                                                      batch_size=64,
                                                      test_size=0.25)
However, there is a problem with doing it this way. We normally do not apply augmentations such as RandomFlip (or random scale, random crop, etc.) to the validation/test data. How can I disable the transforms on the validation data loader? I know that we could write a script to split the data into train and validation sets first and then create two datasets with different transforms, but is there any other way to do this?
Any help would be appreciated.
Thank you.