Speed up Dataloader using the new Torchvision Transforms support for Tensor, batch computation, GPU

Hello there,

According to this torchvision release, transformations can now be applied directly to tensors and to batched tensors. It says:

torchvision transforms are now inherited from nn.Module and can be torchscripted and applied on torch Tensor inputs as well as on PIL images. They also support Tensors with batch dimension and work seamlessly on CPU/GPU devices

Here is a snippet:

import torch
import torchvision.transforms as T

transforms = torch.nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
    T.ConvertImageDtype(torch.float),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)

tensor_image = torch.randint(0, 256, size=(3, 256, 256), dtype=torch.uint8)

# works directly on Tensors
out_image1 = transforms(tensor_image)

# on the GPU
out_image1_cuda = transforms(tensor_image.cuda())

# with batches
batched_image = torch.randint(0, 256, size=(4, 3, 256, 256), dtype=torch.uint8)
out_image_batched = transforms(batched_image)

How can we use this to improve our dataloader's performance?

On my computer, loading and transforming a batch of 256 ImageNet images takes approximately 5 s. Given that I will move the input tensors to the device anyway, why not apply the transformations after that step?

The typical workflow is to define a composition of transforms and pass it to a torch dataset. The composition converts PIL images to tensors in the final step, as follows:

import torchvision.transforms as T

transform = T.Compose([
    T.Resize(...),
    T.RandomCrop(...),
    T.RandomHorizontalFlip(...),
    # and finally
    T.ToTensor(),
])

This doesn’t make use of the fact that transformations can be applied on device tensors directly.

My thought is to add a transformation step inside the data loading loop, as follows:

for i, (input, target) in enumerate(dataloader):
    # move data to GPU
    input = input.to(device)
    target = target.to(device)
    # apply transformations on the device
    input = transform(input)
    # feed to model
    output = model(input)

However, this approach is not as clean and seamless as simply passing a transform to a dataset, though I am afraid it is the only possible approach.

I am not sure if this approach will even speed up loading. Will it?

Any thoughts?

Hi! Have you figured this out? I’m also interested in improving computation times for transformations.

Hello! I actually went ahead and ran an experiment on the STL10 dataset. The approach I mentioned here, applying the transformations inside the training loop, works well and speeds up the transformations. (I didn’t manage to find a cleaner way, but after using it I found it maintainable.) You can also set pin_memory=True in the data loaders to speed up moving data to the GPU. Finally, if a transformation is deterministic rather than random, you can cache its output instead of applying it every epoch.

Hi! According to the CIFAR10 example, the dataset and dataloader pipeline is:

transform = transforms.Compose(
    [transforms.RandomCrop(224),
     transforms.RandomHorizontalFlip(p=0.3)]
)

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

The transform is applied inside torchvision.datasets.CIFAR10, so transforms() is called once per image, and the dataloader later collates the results into a batch.
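
This per-sample behaviour is easy to confirm with a toy dataset. The sketch below uses a hypothetical `ToyDataset` and a counting transform in place of CIFAR10 (so nothing is downloaded), and `num_workers=0` so that the counter lives in the main process.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CountingTransform:
    """Identity transform that records how often it is called (illustrative)."""

    def __init__(self):
        self.calls = 0

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return img


class ToyDataset(Dataset):
    """Stand-in for CIFAR10: applies `transform` inside __getitem__."""

    def __init__(self, transform=None):
        self.data = torch.rand(8, 3, 32, 32)
        self.targets = list(range(8))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, target


counter = CountingTransform()
loader = DataLoader(ToyDataset(transform=counter), batch_size=4, num_workers=0)

first_images, first_targets = next(iter(loader))
print(counter.calls)       # one call per sample in the batch
print(first_images.shape)  # samples stacked by the default collate function
```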

Do you think it is good practice to do the following?

# pass this transform to torchvision.datasets.CIFAR10, instead of transforms.Compose()
transform = torch.nn.Sequential(
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.3),
).to('cuda')

# edit inside torchvision.datasets.CIFAR10
def __getitem__(self, index: int) -> Tuple[Any, Any]:
    ...
    if self.transform is not None:
        img = self.transform(transforms.functional.to_tensor(img))
    ...
# Run our code as normal