State-of-the-art for torchvision datasets/transforms/models design

Hi, I’m developing a library similar to torchvision for geospatial data. Think multi-spectral satellite imagery datasets/transforms/pre-trained models.

When perusing the torchvision source code, I noticed that there are a lot of inconsistencies due to the long history of development and changing needs of users. For example:

  • Datasets return images and targets, or images and masks and classes, or images and dicts with bounding boxes
  • Datasets accept transform, target_transform, and/or transforms
  • Transforms support PIL Images and/or torch Tensors
  • Transforms subclass object, or nn.Module (for TorchScript support)

My question is: if you were going to write a library like torchvision from scratch, which of these are state-of-the-art best practices, and which are just leftovers kept for backwards compatibility? Is it best practice to write all new transforms as subclasses of nn.Module, or is there some advantage to plain object transforms? Since I’m working with multi-spectral imagery, PIL Images won’t work at all, so I’ll likely write all transforms for torch Tensors. If I want consistency between datasets and transforms for various tasks (object detection, instance segmentation, etc.), should all datasets return (and all transforms accept) dicts with keys for the possible components (image, mask, bounding boxes, class labels, etc.)?
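
For illustration, here is the kind of unified sample dict I have in mind (the keys and shapes below are hypothetical, not an existing API):

import torch

# A hypothetical unified sample dict covering multiple tasks.
sample = {
    "image": torch.rand(12, 256, 256),                  # 12-band multi-spectral image
    "mask": torch.zeros(256, 256, dtype=torch.long),    # semantic segmentation mask
    "boxes": torch.tensor([[10.0, 20.0, 50.0, 60.0]]),  # (x1, y1, x2, y2) boxes
    "labels": torch.tensor([3]),                        # class labels for the boxes
}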

I’m also curious what other issues the torchvision developers have faced over the years and how they solved them. Would love to meet over a video call.


As for the transforms, we faced similar issues for n-dimensional arrays, which is why we started kornia.org and subsequently the kornia.augmentation submodule. We are now adding more advanced augmentations and specialised containers to handle other types of data like masks, keypoints, etc. See it here: https://app.reviewnb.com/kornia/tutorials/pull/12/


Hi Adam,

This is a great question. Indeed, torchvision has accumulated a lot of tech debt over time, which is hard to change due to BC considerations, but which we would gladly change otherwise.

We have started some discussions around things we would love to change in torchvision if we didn’t have any BC considerations. We will post a summary of the discussion on GitHub once it’s fleshed out, but here are some points specific to your questions:

Datasets

We are in the process of re-designing all of our datasets so that they are built on top of DataPipes. This will allow them to stream data both from disk and from the internet, similar to what WebDataset does.

In this process, we will also be making the return types of the datasets more consistent – they will probably be dicts everywhere, with fields for the image / label / etc. (this is still being fleshed out, and @pmeier is actively working on the design).

In this context, a transform will no longer be part of the dataset, but will just be a .map call on top of the DataPipe.
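
As a rough sketch of what that could look like, using torchdata’s IterableWrapper as a stand-in for a dataset datapipe (the dict keys and the normalization here are illustrative, not the final API):

import torch
from torchdata.datapipes.iter import IterableWrapper

# Stand-in for a dataset datapipe that yields sample dicts (keys are illustrative).
samples = IterableWrapper(
    [{"image": torch.rand(3, 32, 32), "label": i % 10} for i in range(8)]
)

def normalize(sample):
    # The transform operates on the sample dict, outside of the dataset itself.
    sample["image"] = (sample["image"] - 0.5) / 0.5
    return sample

pipe = samples.map(normalize)  # the transform is just a .map over the DataPipe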

Transforms

PIL and Tensors

We started with PIL because it was widely available in Python and was kind of the standard for image processing. With the push for sharing the same transform implementation between Python and C++, we enabled Tensor Transforms, which also give other benefits such as GPU and autograd.
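
A quick illustration of those benefits, using one of the existing tensor transforms (the device choice is just for the example):

import torch
import torchvision.transforms as T

device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.rand(1, 3, 64, 64, device=device, requires_grad=True)

blur = T.GaussianBlur(kernel_size=5)  # an nn.Module that works directly on tensors
out = blur(img)       # runs on the GPU if one is available
out.sum().backward()  # and is differentiable end-to-end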

In my view, the PIL-based transforms are mostly legacy now, but we won’t be deprecating them until we have closed all the gaps that remain compared to the Tensor Transforms (speed and antialiasing for resize are two things we are improving in the upcoming release).

About making the transforms nn.Module subclasses vs plain objects: this was done to facilitate TorchScript support.
That being said, I’m not super happy with the current class API for transforms, as it doesn’t allow for a nice way of enabling joint transforms over boxes / images / etc., so the class-based API should be rethought. For now, only the functional API for the transforms is powerful enough; see the sketch below.
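
To illustrate, here is a minimal joint-transform sketch with the functional API, where the randomly sampled parameter is shared between the image and the mask:

import random
import torch
import torchvision.transforms.functional as F

img = torch.rand(3, 64, 64)
mask = torch.zeros(1, 64, 64)

# Sample the parameter once, then apply it to every component jointly;
# the class-based random transforms cannot express this cleanly.
angle = random.uniform(-30.0, 30.0)
img = F.rotate(img, angle)
mask = F.rotate(mask, angle)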

I think it would be interesting to consider that the concept of a “model” can cover everything from data decoding up to the final model output, which means that transforms could be part of the model itself (and thus analogous to an nn.Module such as nn.Conv2d). The current difficulty is how to make this efficient without also making the dataloader part of the model.
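
Since the tensor transforms already subclass nn.Module, a minimal sketch of that idea (the specific transforms and model here are just placeholders):

import torch
from torch import nn
import torchvision.transforms as T
from torchvision.models import resnet18

# Preprocessing lives inside the model, so it travels with the weights.
model = nn.Sequential(
    T.Resize(224),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    resnet18(),
)
out = model(torch.rand(1, 3, 256, 256))  # preprocessing happens in the forward pass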

For a few more specific questions:

should all datasets return (and all transforms accept) dicts with keys for possible components

Yes, I think that’s what we would be doing for the new dataset implementations. For the transforms, I don’t yet know, but a redesign would be good.

Happy to chat over a video call, ping me on slack and we can continue the discussion.


Hi Adam,

While working on deep learning tasks, I much prefer to decouple those three components as much as possible. So, for each of them:

  • Datasets read images and labels and convert them to the proper dtype, nothing more.
  • Transforms perform the transformations, with optimized batch-transformation support. It should also be painless to transform the common label types (e.g. masks, bounding boxes, keypoints).
  • Models take in images and output predictions; things like normalization should happen inside the model rather than in the augmentations, so that it is more “end-to-end”.

Kornia’s augmentation module offers solutions for this decoupling.

Transforms

Firstly, batch processing is the basis of kornia’s transformations, which takes advantage of PyTorch itself (tensor-based computation, nn.Module implementations). Recently, we introduced AugmentationSequential to offer some syntactic sugar for image augmentations. For example:

import kornia.augmentation as K
from kornia.augmentation import AugmentationSequential

aug_list = AugmentationSequential(
    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
    K.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30., 50.], p=1.0),
    K.RandomPerspective(0.5, p=1.0),
    data_keys=["input", "bbox", "keypoints", "mask"],  # declare the future inputs here
    return_transform=False,
    same_on_batch=False,
)
# Forward the operation
out_tensors = aug_list(img_tensor, bbox, keypoints, mask)
# Invert the operation
out_tensors_inv = aug_list.inverse(*out_tensors)

The results then look like the following (left → right: original, transformed, inverted back):

Datasets and Models

As you may have noticed, I skipped the datasets part at first :smiley:, since we have only just introduced batch augmentation. Unlike the traditional augmentation strategy (including the transformations in the Dataset implementation), we recommend moving the image augmentation computations to the next stage. For example:

import torch
import torch.nn as nn
import torchvision
import kornia.augmentation as K
from kornia.augmentation import AugmentationSequential

# Define the augmentation
augment = AugmentationSequential(
    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
    K.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30., 50.], p=1.0),
    data_keys=["input", "bbox"],
)
# Define the model
model = torchvision.models.resnet50()
# Define a more "end-to-end" model, with normalization inside
model = nn.Sequential(
    K.Normalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.6, 0.6, 0.6])),
    model,
)

As a result, this adaptation is very convenient for lots of daily routines. For example:

# For training: augment first, then feed only the image to the model
img_aug, bbox_aug = augment(img_tensor, bbox)
model(img_aug)
# For testing: no augmentation needed
model(img_tensor)

Save and load without thinking about preprocessing:

torch.save(model, "model.pt")
model = torch.load("model.pt")
# No need to think about normalization; it is baked into the saved model.
model(img_tensor)

So far, I think Kornia augmentation resolves most of the issues you are facing. Let us know if anything is missing!


Thanks, this is great! I just submitted an application for the PyTorch Slack. I’ll contact you to set up a meeting with some other folks at Microsoft Research who are working on this with me.

I can’t say I’m completely familiar with DataPipes, but I’ll take a look at that. It sounds like we’re in agreement that Datasets should return dicts and that transforms should accept Tensors (especially for my multi-spectral imagery application). @pmeier I would be very interested to know your thoughts on standard dict key names for this.


I’d never heard of Kornia – this is really interesting. I especially like the ability to invert augmentation sequences. The ability to save augmentations as part of the model is also interesting, and further emphasizes the need to subclass nn.Module so that these models/transforms can be exported to ONNX. I’ll keep this in mind!


The augmentation module in kornia is just an extension of the package. Our main focus is low-level computer vision, more for industrial and robotics applications, in terms of differentiable tensor operations. We inherit a lot from OpenCV (and soon more), and in the long term we also aim to express vision algorithms in a “symbolic” way, as you mention, with technologies like ONNX or similar.

@adamjstewart

I would be very interested to know your thoughts on standard dict key names for this.

There is currently no consensus on that. We have some ideas, though. Maybe I can join the call and we can discuss this a little there.

For anyone following this discussion, TorchGeo has now been made public: GitHub - microsoft/torchgeo: TorchGeo: datasets, transforms, and models for geospatial data

We’re hoping to put out a release within the next month or so. So far, things are mostly in the old style (Dataset instead of DataPipes), but we’re less concerned about backwards compatibility, so we can always change this later.