Creating a TorchScript wrapper?

I’m aware that TorchScript does not support mixed serialization, but is there any way to easily save a TorchScript file with a preprocessing wrapper? The actual inputs to my models are vectorized features derived from a text file, and the chosen features are model hyperparameters. I wanted to hide that process under the hood, so users can pass a raw text file and have the preprocessing steps wrapped around the TorchScript object, but I can’t see how to do this without mixed serialization.

Is there a workaround?

Just make a class wrapper. First, make all preprocessing functions TorchScript-friendly, then you can create a class like:

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, preprocess_fn, model):
        super().__init__()  # required so nn.Module is set up before assigning submodules
        self.preprocess_fn = preprocess_fn
        self.model = model

    def forward(self, x):
        # Remember: by default, x is expected to be a torch.Tensor.
        return self.model(self.preprocess_fn(x))

wrapper = Wrapper(your_preprocess_function, your_model)
wrapper_ts = torch.jit.script(wrapper)

And that’s all!
I used the same approach, but to integrate some post-processing steps without needing to deploy extra Python files.
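For reference, a minimal sketch of that post-processing variant (the argmax step, the stand-in nn.Linear model, and the file name are hypothetical); the scripted module can then be saved as a single file and reloaded with torch.jit.load:

import torch
import torch.nn as nn

class PostprocessWrapper(nn.Module):
    # Hypothetical example: bundle a model and its post-processing
    # (here, turning logits into class indices) into one scriptable module.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        logits = self.model(x)
        return torch.argmax(logits, dim=-1)  # post-processing baked into the graph

wrapped = torch.jit.script(PostprocessWrapper(nn.Linear(10, 3)))  # nn.Linear stands in for your model
wrapped.save("model_with_postprocess.pt")   # single file, no extra Python code to ship
loaded = torch.jit.load("model_with_postprocess.pt")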

@vferrer Hello! Your reply is very helpful, but I need some clarification. What do you mean by “TorchScript-friendly” functions? What if I need to have an audio file as input and the preprocessing function is the one that extracts the feature tensor?

@Sele By “TorchScript-friendly” I mean Python functions that can be scripted, i.e. ts_fn = torch.jit.script(fn) succeeds, since they will be compiled to TorchScript (they are called inside the forward pass).
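To make that concrete, here is a minimal sketch (not your actual audio pipeline; the framing parameters are made up): decoding the audio file to a waveform tensor usually happens outside the scripted graph, while the tensor-to-tensor feature extraction can be scripted and placed inside forward:

import torch

# A TorchScript-friendly preprocessing function: tensor in, tensor out,
# with type annotations and no file I/O or arbitrary Python objects inside.
def extract_features(waveform: torch.Tensor) -> torch.Tensor:
    # Hypothetical features: frame a 1-D waveform and take per-frame log-energies.
    # A real pipeline might compute a spectrogram here instead.
    frames = waveform.unfold(0, 400, 160)             # shape: (n_frames, 400)
    return torch.log(frames.pow(2).mean(dim=-1) + 1e-8)

ts_fn = torch.jit.script(extract_features)  # if this compiles, it can be called inside forward()

The scripted function can then be stored on the wrapper module exactly as in the example above.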