I’m aware that torchscript does not support mixed serialization, but is there any way to easily save a torchscript file with a preprocessing wrapper? The actual inputs to my models are vectorized features derived from a text file, and the chosen features are model hyperparameters. I wanted to hide that process under the hood, so that users can pass a raw text file and have the preprocessing steps wrapped around the torchscript object, but I can’t see how to do this without mixed serialization.
Just make a class wrapper. First, make all preprocessing functions torchscript-friendly; then you can create a class like:
class Wrapper(nn.Module):
    def __init__(self, preprocess_fn, model):
        super().__init__()  # required when subclassing nn.Module
        self.preprocess_fn = preprocess_fn
        self.model = model

    def forward(self, x):
        # Remember: by default, TorchScript expects x to be a torch.Tensor.
        return self.model(self.preprocess_fn(x))

wrapper = Wrapper(your_preprocess_function, your_model)
wrapper_ts = torch.jit.script(wrapper)
And that’s all!
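To make the pattern above concrete, here is a minimal, self-contained sketch. The normalization preprocessor and the tiny linear model are made-up stand-ins (not from this thread); the preprocessing is written as a scriptable function called from the wrapper's forward:

```python
import torch
import torch.nn as nn

@torch.jit.script
def extract_features(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical preprocessing: normalize to zero mean / unit variance.
    return (x - x.mean()) / (x.std() + 1e-8)

class Wrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The scripted preprocessing runs inside the compiled graph.
        return self.model(extract_features(x))

model = nn.Linear(4, 2)           # stand-in for your trained model
wrapper_ts = torch.jit.script(Wrapper(model))
out = wrapper_ts(torch.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])
```

Because the preprocessing is compiled into the saved module, `wrapper_ts.save("model.pt")` produces a single artifact that users can load without any extra Python files.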
I used the same approach, but to integrate some post-processing steps without needing to deploy extra Python files.
@vferrer Hello! Your reply is very helpful, but I need some clarifications. What do you mean by “functions torchscript friendly”? What if I need to have an audio file as input and the preprocessing function is the one that extracts the feature tensor?
@Sele By “torchscript-friendly functions” I mean Python functions that can be scripted, i.e. ts_fn = torch.jit.script(fn), since they will be compiled to TorchScript (they run inside the forward pass). Note that TorchScript operates on tensors, so reading and decoding the audio file itself typically stays outside the scripted graph; the scripted function should take the decoded waveform tensor and extract the features from it.
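As a minimal illustration of what “scriptable” means here, a standalone function using only tensor operations compiles directly (the transform below is hypothetical, just to show the mechanism):

```python
import torch

def preprocess(x: torch.Tensor) -> torch.Tensor:
    # A TorchScript-compatible transform: clamp negatives, then scale.
    return torch.clamp(x, min=0.0) * 2.0

# Compile the function to TorchScript; fails if it uses unsupported Python.
ts_fn = torch.jit.script(preprocess)
result = ts_fn(torch.tensor([-1.0, 0.5]))
print(result)  # tensor([0., 1.])
```

If scripting raises an error, it usually means the function relies on Python features TorchScript cannot compile (e.g. file I/O or arbitrary third-party calls), and that part has to move outside the scripted graph.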