Updating the transform instance variable of a Dataset instance

This seems like a bug to me, so I am not sure this is the correct place for the question, but here it is.

I have declared a custom dataset, something like this:

import os

import numpy as np
import pandas as pd
import torch
from skimage import io
from torch.utils.data import Dataset

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

Now I declare an instance of this class:


from torch.utils.data import random_split

# val_percent and transform(...) are defined elsewhere in my code
dataset = FaceLandmarksDataset('file_path.csv', 'img_dir')
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train.transform = transform(useAugmentations=True)
val.transform = transform(useAugmentations=False)

The code runs, but the transformation is not applied. However, if I pass the transformation when constructing the dataset, as below, it works fine.

dataset = FaceLandmarksDataset('file_path.csv', 'img_dir', transform(useAugmentations=True))
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])

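So how can I override an instance variable of a Dataset object, and why is it not updating as it should?

random_split returns two Subset instances, and both of them refer to the (single!) Dataset instance. Assigning train.transform only creates a new attribute on the Subset object, which Subset.__getitem__ never reads: it simply returns self.dataset[self.indices[idx]], so the dataset's own transform stays None. Setting train.dataset.transform instead would not help either, because val shares that same dataset and would get the augmentations too.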
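To see this concretely, here is a minimal sketch using a hypothetical ToyDataset (a stand-in for FaceLandmarksDataset that just returns its index, optionally transformed). The fixed generator seed and the lambda transforms are illustration-only assumptions.

```python
import torch
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):
    """Hypothetical stand-in for FaceLandmarksDataset: returns idx, optionally transformed."""

    def __init__(self, n, transform=None):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        sample = idx
        if self.transform:
            sample = self.transform(sample)
        return sample

dataset = ToyDataset(10)
train, val = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(0))

# Both Subsets wrap the very same dataset object.
print(train.dataset is val.dataset is dataset)  # True

# This only creates a new attribute on the Subset; Subset.__getitem__
# never reads it, it just returns self.dataset[self.indices[idx]].
train.transform = lambda x: x * 100
print(train[0] == dataset[train.indices[0]])  # True: no transform applied

# Setting the transform on the shared dataset "works", but for train AND val.
train.dataset.transform = lambda x: x * 100
print(val[0] == 100 * val.indices[0])  # True: val is transformed too
```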

The modern workflow for this is to override the collate function in the DataLoader (call the default collate function, then optionally augment the batch).
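A rough sketch of the collate approach, assuming dict samples like the ones FaceLandmarksDataset returns. ToyDataset, make_collate and the additive-noise "augmentation" are hypothetical placeholders (noise is used because, unlike a flip, it does not require moving the landmarks); substitute your own augmentation logic.

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data.dataloader import default_collate

class ToyDataset(Dataset):
    """Hypothetical dataset returning dict samples like FaceLandmarksDataset."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"image": torch.full((3, 4, 4), float(idx)),
                "landmarks": torch.zeros(5, 2)}

def make_collate(augment):
    """Return a collate_fn that batches with default_collate, then optionally augments."""
    def collate(samples):
        batch = default_collate(samples)  # dict of stacked tensors
        if augment:
            # Placeholder augmentation: small additive noise on the images only.
            batch["image"] = batch["image"] + 0.01 * torch.randn_like(batch["image"])
        return batch
    return collate

dataset = ToyDataset()
train, val = random_split(dataset, [6, 2], generator=torch.Generator().manual_seed(0))

# Same shared dataset, but each loader decides whether to augment.
train_loader = DataLoader(train, batch_size=3, shuffle=True,
                          collate_fn=make_collate(augment=True))
val_loader = DataLoader(val, batch_size=2, collate_fn=make_collate(augment=False))

batch = next(iter(train_loader))
print(batch["image"].shape)  # torch.Size([3, 3, 4, 4])
```

Because the augmentation decision lives in the loader rather than in the dataset, the two Subsets can keep sharing one dataset object.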
Alternatively, you could create two full instances of your dataset, one per transform, and split the indices yourself. In the end, random_split just runs idx through an index indirection generated by randperm and adjusts the length.
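A sketch of the two-instance approach under these assumptions: ToyDataset is a hypothetical stand-in for FaceLandmarksDataset, and the lambda transform stands in for transform(useAugmentations=True). Both instances are indexed through one shared randperm permutation, so the train and val index sets stay disjoint, mirroring what random_split does internally.

```python
import torch
from torch.utils.data import Dataset, Subset

class ToyDataset(Dataset):
    """Hypothetical stand-in for FaceLandmarksDataset."""

    def __init__(self, n, transform=None):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        sample = float(idx)
        if self.transform:
            sample = self.transform(sample)
        return sample

val_percent = 0.2
# Two full instances, each with its own transform (placeholders for
# transform(useAugmentations=True) and transform(useAugmentations=False)).
train_full = ToyDataset(10, transform=lambda x: x + 0.5)
val_full = ToyDataset(10, transform=None)

n_val = int(len(train_full) * val_percent)
n_train = len(train_full) - n_val

# One random permutation, cut into two disjoint index lists.
perm = torch.randperm(len(train_full), generator=torch.Generator().manual_seed(0)).tolist()
train = Subset(train_full, perm[:n_train])
val = Subset(val_full, perm[n_train:])
```

With the real FaceLandmarksDataset this reads the CSV twice, which is usually cheap compared with loading the images.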

Best regards


Thanks! That helped a lot!
