State-of-the-art for torchvision datasets/transforms/models design

Hi Adam,

Whilst working on deep learning tasks, I would much prefer to decouple these three components as much as possible, with each one responsible for a single job:

  • Datasets: read images and labels and convert them to the proper dtype, nothing more.
  • Transforms: perform the transformations, ideally with optimized batch processing. It should also be painless to transform common labels (e.g. masks, bounding boxes, keypoints) together with the images.
  • Models: take in images and output predictions. Steps such as normalization should happen inside the model rather than in the augmentation pipeline, so that the model is more “end-to-end”.
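To make the Datasets role concrete, here is a minimal sketch of a dataset that only fixes dtypes. `ImageLabelDataset` is a hypothetical name, and the sketch assumes the images are already decoded into uint8 tensors of shape (C, H, W); a real dataset would read them from disk instead.

```python
import torch
from torch.utils.data import Dataset


class ImageLabelDataset(Dataset):
    """Hypothetical dataset that only converts samples to the proper dtype.

    No augmentation or normalization happens here: augmentation is left to
    the batched transform stage and normalization to the model.
    """

    def __init__(self, images, labels):
        # images: list of uint8 tensors shaped (C, H, W), assumed already decoded
        # labels: list of ints
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # uint8 in [0, 255] -> float32 in [0, 1]; nothing else
        image = self.images[idx].to(torch.float32) / 255.0
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
```

Because the dataset does no augmentation, the default DataLoader collate simply stacks these samples into a (B, C, H, W) batch that can be handed to a batch augmentation module.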

Kornia augmentation offers a solution for this kind of decoupling.

Transforms

Firstly, batch processing is the basis of Kornia transformations, which take advantage of PyTorch itself (tensor-based computation, so they run on the GPU; nn.Module implementation, so they compose like any other layer). Recently, we introduced AugmentationSequential to offer syntactic sugar for image augmentation pipelines. For example:

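import kornia.augmentation as K
from kornia.augmentation import AugmentationSequential

# img_tensor: (B, C, H, W); bbox, keypoints and mask: the matching per-image annotations
aug_list = AugmentationSequential(
    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
    K.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30., 50.], p=1.0),
    K.RandomPerspective(0.5, p=1.0),
    data_keys=["input", "bbox", "keypoints", "mask"],  # declares the order of the inputs passed below
    same_on_batch=False,  # sample different random parameters for each image in the batch
)  # return_transform=False (the default) is omitted; later Kornia releases removed that argument
# Forward: augment the images and transform their annotations consistently
out_tensors = aug_list(img_tensor, bbox, keypoints, mask)
# Inverse: map the outputs back to the original geometry
out_tensor_inv = aug_list.inverse(*out_tensors)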

The results look like this (left to right: original, transformed, inverted back):
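The idea behind batched, invertible geometric augmentation can be illustrated without Kornia. The sketch below is plain PyTorch, not Kornia's implementation: the helpers `sample_affine` and `apply_to_points` are hypothetical, and for brevity the rotation is about the origin rather than the image center. It draws separate parameters for every sample in one vectorized call (the analogue of `same_on_batch=False`), moves keypoints with them, and undoes the move with the matrix inverse.

```python
import math
import torch


def sample_affine(batch_size, max_degrees=30.0, scale_range=(0.7, 1.2), generator=None):
    """Hypothetical helper: one random rotation+scale matrix per sample.

    Returns a (B, 3, 3) tensor of homogeneous 2D transforms, so every image
    in the batch gets its own parameters from a single vectorized call.
    """
    angles = (torch.rand(batch_size, generator=generator) * 2 - 1) * math.radians(max_degrees)
    lo, hi = scale_range
    scales = lo + torch.rand(batch_size, generator=generator) * (hi - lo)
    cos, sin = torch.cos(angles) * scales, torch.sin(angles) * scales
    mat = torch.zeros(batch_size, 3, 3)
    mat[:, 0, 0], mat[:, 0, 1] = cos, -sin
    mat[:, 1, 0], mat[:, 1, 1] = sin, cos
    mat[:, 2, 2] = 1.0
    return mat


def apply_to_points(mat, points):
    """Hypothetical helper: apply (B, 3, 3) transforms to (B, N, 2) keypoints."""
    ones = torch.ones(*points.shape[:-1], 1, dtype=points.dtype)
    homo = torch.cat([points, ones], dim=-1)   # (B, N, 3) homogeneous coordinates
    out = homo @ mat.transpose(1, 2)           # row-vector form of M @ p, per sample
    return out[..., :2] / out[..., 2:]


keypoints = torch.rand(4, 5, 2) * 100                      # 4 images, 5 keypoints each
mat = sample_affine(keypoints.shape[0])
moved = apply_to_points(mat, keypoints)                    # forward
restored = apply_to_points(torch.linalg.inv(mat), moved)   # inverse
```

Keeping the sampled matrices around is what makes the inverse possible; the same matrices would also be used to warp the images and masks so that all annotations stay aligned.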

Datasets and Models

As you may have noticed, I skipped the datasets part at first :smiley:, since it builds on the batch augmentation introduced above. Unlike the traditional augmentation strategy (including the transformations in the Dataset implementation), we recommend moving the image augmentation computations to the next stage, after batching, where they operate on whole batches. For example:

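import torch
import torchvision
from torch import nn
import kornia.augmentation as K
from kornia.augmentation import AugmentationSequential

# Define the augmentation (applied to whole batches during training)
augment = AugmentationSequential(
    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
    K.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30., 50.], p=1.0),
    data_keys=["input", "bbox"],
)
# Define the model
model = torchvision.models.resnet50()
# Define a more "end-to-end" model: normalization becomes the first layer
model = nn.Sequential(
    K.Normalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.6, 0.6, 0.6])),
    model,
)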

As a result, this arrangement is very convenient for many everyday routines. For example:

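# For training: augment the batch, then feed only the images to the model
model.train()
img_aug, bbox_aug = augment(img_tensor, bbox)
pred = model(img_aug)  # bbox_aug would go to the loss in a detection setting
# For testing: no augmentation; normalization still happens inside the model
model.eval()
pred = model(img_tensor)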
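The same split can be sketched end to end in plain PyTorch. `RandomBrightness` and `Normalize` here are hypothetical stand-ins for `K.ColorJitter` and `K.Normalize`, and the tiny linear model replaces resnet50 so the example runs quickly; the point is only where each piece sits.

```python
import torch
from torch import nn


class RandomBrightness(nn.Module):
    """Hypothetical batch augmentation: one random brightness shift per sample."""

    def __init__(self, max_delta=0.1):
        super().__init__()
        self.max_delta = max_delta

    def forward(self, images):
        # images: (B, C, H, W) in [0, 1]; one delta per image, broadcast over C, H, W
        delta = (torch.rand(images.shape[0], 1, 1, 1, device=images.device) * 2 - 1) * self.max_delta
        return (images + delta).clamp(0.0, 1.0)


class Normalize(nn.Module):
    """Hypothetical fixed per-channel normalization that lives inside the model."""

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, images):
        return (images - self.mean) / self.std


augment = RandomBrightness(0.1)
model = nn.Sequential(
    Normalize([0.5, 0.5, 0.5], [0.6, 0.6, 0.6]),
    nn.Flatten(),
    nn.Linear(3 * 8 * 8, 2),  # stand-in for a real backbone such as resnet50
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

images = torch.rand(4, 3, 8, 8)  # a batch as produced by a DataLoader
labels = torch.tensor([0, 1, 0, 1])

# Training step: augmentation runs on the whole batch, outside the Dataset
model.train()
loss = nn.functional.cross_entropy(model(augment(images)), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Testing: no augmentation; normalization still happens inside the model
model.eval()
with torch.no_grad():
    preds = model(images).argmax(dim=1)
```

Since the augmentation module is only called in the training step, switching to evaluation needs no change to the data pipeline.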

Save and load without thinking about preprocessing:

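# Save the whole module (preprocessing included) and load it back
torch.save(model, "model.pt")
model = torch.load("model.pt", weights_only=False)  # full pickle: Kornia must be importable when loading
model.eval()
# No separate normalization step: it is already part of the saved model.
model(img_tensor)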
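If pickling the whole module is undesirable (loading it requires the same class definitions and, on recent PyTorch versions, `weights_only=False`), an alternative is to save only the `state_dict`. A minimal sketch, assuming a hypothetical `build_model()` factory and a hypothetical `Normalize` module: because `mean` and `std` are registered as buffers, the normalization constants are stored in the `state_dict` next to the weights, follow `.to(device)`, and are restored together with them.

```python
import io

import torch
from torch import nn


class Normalize(nn.Module):
    """Hypothetical per-channel normalization; mean/std are buffers, so they are saved."""

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, images):
        return (images - self.mean) / self.std


def build_model():
    # Hypothetical architecture: preprocessing followed by a tiny stand-in backbone
    return nn.Sequential(
        Normalize([0.5, 0.5, 0.5], [0.6, 0.6, 0.6]),
        nn.Flatten(),
        nn.Linear(3 * 8 * 8, 2),
    )


model = build_model().eval()
buffer = io.BytesIO()  # stands in for a file such as "model.pt"
torch.save(model.state_dict(), buffer)

buffer.seek(0)
restored = build_model().eval()
restored.load_state_dict(torch.load(buffer))

images = torch.rand(2, 3, 8, 8)
with torch.no_grad():
    same = torch.allclose(model(images), restored(images))
```

The trade-off is that the architecture must be rebuilt in code before loading, whereas the full-module save shown above restores everything in one call.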

So far, I think Kornia augmentation resolves most of the issues you are facing. Let us know if anything is missing!
