Can torchvision transforms be added to a TorchScript model?

I have a basic image classifier that I want to package as a TorchScript model. I plan to deploy it in a Haskell application, calling the C++ API (LibTorch) through a foreign function interface. This works fine, but I would also like to bundle the necessary preprocessing transforms into the model. My model is a fine-tuned ResNet, and during training I use the provided transform to resize, center crop, etc. It seems unfortunate to have to duplicate this pipeline manually in my inference application, and I was hoping I could just add it to my TorchScript model.

My naive attempt is:

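```python
import pytorch_lightning as pl
import torch


class E2EClumpDetector(pl.LightningModule):
    def __init__(self, clump_detector):
        super().__init__()
        self.clump_detector = clump_detector

    def forward(self, batch):
        # Fails under TorchScript: clump_detector.transform is a torchvision Compose
        return torch.nn.functional.sigmoid(self.clump_detector(self.clump_detector.transform(batch)))
```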

But this fails with:

```
RuntimeError:
Module 'ClumpDetection' has no attribute 'transform' (This attribute exists on the Python module, but we failed to convert Python type: 'torchvision.transforms.transforms.Compose' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type Compose.. Its type was inferred; try adding a type annotation for the attribute.):
  File "/home/ollie/work/clump-detection/clump_detection.py", line 76
    def forward(self, batch):
        return torch.nn.functional.sigmoid(self.clump_detector(self.clump_detector.transform(batch)))
```

So it seems a `Compose` can't be used inside a TorchScript model. Does anyone have advice that doesn't require me to duplicate the series of transforms?

Ah, I've just found the "Transforming and augmenting images" page in the Torchvision 0.16 documentation! Maybe it's really as simple as swapping out `Compose` for `torch.nn.Sequential`.

Edit: yes, it was that simple! I should have just RTFM first :smile: