Applying torchvision.transforms to local variables

Newbie here.

I have a dataset of images that I want to train a model on. I want to train first on the “base” images, then apply augmentations to the images in a GRADUAL way (more and more aggressive) and keep training.

So, I use:

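import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

base_transform = transforms.Compose([
        transforms.Resize((224, 224)),      # resize to 224x224
        transforms.ToTensor()               # PIL image -> float tensor in [0, 1]
    ])

base_data = datasets.ImageFolder(os.path.join(DATASET_PATH, 'train'), transform=base_transform)
# one batch containing the whole dataset
base_loader = DataLoader(base_data, batch_size=len(base_data), shuffle=True)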

#This part takes 2.5 min, but should be done only ONCE
X_train_base,y_train_base = next(iter(base_loader)) #X_train_base and y_train_base are tensors on the CPU

#Push to the GPU
X_train_base = X_train_base.to(device)
y_train_base = y_train_base.to(device)

to create the base variables (and they are now tensors on the GPU). No problem here.

Now, I want to create the augmentation transformation:

aug_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # reverse 50% of images
        transforms.RandomAffine(degrees = (0,45), translate=(0.3,0.3), scale=(0.8,1.2), shear=None, resample=False, fillcolor=0),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

How do I now apply it to my base variables? I could only figure out how to apply transforms to image files on disk.

In newer torchvision versions (0.8 and later) the transforms accept tensors, so you should be able to apply the transformation directly to a batch of tensors, e.g. inside your training loop. Note that recent releases replaced the resample and fillcolor arguments of RandomAffine with interpolation and fill, so I've used fill here:

aug_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # horizontal flip with probability 0.5
        transforms.RandomAffine(degrees=(0, 45), translate=(0.3, 0.3), scale=(0.8, 1.2), shear=None, fill=0),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

x = torch.rand(10, 3, 224, 224)  # dummy batch in [0, 1], like the output of ToTensor
out = aug_transform(x)

Thank you! That is exactly what I was looking for.