Apply different custom functions to different subsets

I have a custom Dataset class:

from pandas import read_csv
from torch.utils.data import Dataset, DataLoader, random_split

class MYDataset(Dataset):
    def __init__(self, path):
        df = read_csv(path, header=None, delimiter=r"\s+")
        data = df.iloc[:, 1:-1].values
        self.X = data[:, 1:]   # features
        self.y = data[:, 0]    # target
        self.y = self.y.reshape((len(self.y), 1))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]

    def get_splits(self, n_test=0.3):
        # n_test is split evenly between validation and test
        testing_size = round(n_test * len(self.X))
        train_size = len(self.X) - testing_size
        val_size = round(testing_size / 2)
        test_size = testing_size - val_size
        return random_split(self, [train_size, val_size, test_size])

def prepare_data(path):
    dataset = MYDataset(path)
    train, val, test = dataset.get_splits()
    train_dl = DataLoader(train, batch_size=64, shuffle=True)
    val_dl = DataLoader(val, batch_size=len(val), shuffle=False)
    test_dl = DataLoader(test, batch_size=len(test), shuffle=False)
    return train_dl, val_dl, test_dl

and I have 2 other custom functions:

def augment_data(array):
    ...
    return gen_data

def BINNED(dataframe):
    ...
    return data_ohe, 0, 1

How can I apply the function BINNED to the train, val, and test subsets, but apply the function augment_data only to the train subset?

cc @VitalyFedyunin re: Data loader question

You could create three different datasets for training, validation, and testing, each with the corresponding transformation. Afterwards, you could wrap each of them in a Subset and pass the desired indices to them, so that no samples are repeated across splits.
This approach needs more memory if you are preloading the data; depending on your use case, you might want to enable lazy loading instead.

@ptrblck Thank you for the suggestion. I’m new to pytorch. Do you have an example I can look at?

Something like this should work:

from torch.utils.data import DataLoader, Subset

train_dataset = MyDataset(transform=train_transform)
val_dataset = MyDataset(transform=val_transform)
test_dataset = MyDataset(transform=test_transform)

# you can use e.g. train_test_split from sklearn here to create the indices
train_idx, val_idx, test_idx = create_indices()

train_dataset = Subset(train_dataset, train_idx)
val_dataset = Subset(val_dataset, val_idx)
test_dataset = Subset(test_dataset, test_idx)

train_loader = DataLoader(train_dataset)
val_loader = DataLoader(val_dataset)
test_loader = DataLoader(test_dataset)
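To connect this back to the original question, here is a minimal runnable sketch. The transform is applied lazily in `__getitem__`, so the underlying arrays are stored only once per dataset. The `bin_features` and `augment` functions below are hypothetical stand-ins for your own `BINNED` and `augment_data`: the train split gets binning plus augmentation, while val and test get binning only.

```python
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset

class MyDataset(Dataset):
    """Holds the full data; the transform runs lazily per sample."""
    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.y[idx]

# Hypothetical stand-ins for BINNED / augment_data
def bin_features(x):                 # applied to every split
    return np.floor(x * 10)

def augment(x):                      # applied to the train split only
    return x + np.random.normal(0, 0.01, size=x.shape)

def train_transform(x):
    return augment(bin_features(x))

eval_transform = bin_features        # no augmentation for val/test

# toy data in place of the CSV
X = np.random.rand(10, 4)
y = np.arange(10, dtype=np.float64).reshape(-1, 1)

# non-overlapping indices so no samples are repeated across splits
train_ds = Subset(MyDataset(X, y, transform=train_transform), range(0, 6))
val_ds   = Subset(MyDataset(X, y, transform=eval_transform), range(6, 8))
test_ds  = Subset(MyDataset(X, y, transform=eval_transform), range(8, 10))

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=len(val_ds), shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=len(test_ds), shuffle=False)
```

Because each split wraps its own dataset instance with its own transform, changing which functions run per split is just a matter of composing a different callable.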