torch.jit.trace() only works on example input?

I’ve created a model with a forward function that takes “x” as input (image of size (3,416,416)). I create a trace of the model using: module = torch.jit.trace(model, example_forward_input), then save that model using module.save("model.pt"). Then I load this model trace into an Android application. When I send an input to the model (from the phone) that is identical to the input used as “example_forward_input”, I get the correct result. However, when I use any other input tensor (same shape), I get poor results. Is this supposed to be the behaviour of the trace function? Is there a function that traces a model that can generalize to any inputs? Any guidance would be much appreciated.

For some more detail: This is a YOLOv3 based model that involves detection and classification. The classification with different inputs into the traced model gives similar results to the same inputs in the model. However, the detection locations differ (in w/h especially) when running an input that was not used as an example through the traced model.

EDIT: I’m guessing this is because my forward method uses control flow that depends on the input, as outlined here. However, when I try to convert the model to a script module, as outlined on that same page, I get the following error:

raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "C:\Users\isaac\Anaconda3\envs\SNProject\lib\site-packages\torch\nn\modules\module.py", line 85
def forward(self, *input):
                  ~~~~~~ <--- HERE
print(torch.__version__)
r"""Defines the computation performed at every call.

As you can see, this is coming from the torch library itself. Any suggestions on how to proceed?

You are right that input-dependent control flow requires torch.jit.script instead of torch.jit.trace. Can you link the YOLOv3 implementation you’re using so we can reproduce this error?
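To illustrate the difference, here is a minimal sketch (not taken from the YOLOv3 model in question) of how tracing bakes in whichever branch the example input happened to take, while scripting keeps the branch. The function name branch_on_sum and the tensors used are hypothetical, chosen only for illustration.

```python
import warnings

import torch


def branch_on_sum(x):
    # Data-dependent control flow: the path taken depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return x - 10


positive = torch.ones(3)
negative = -torch.ones(3)

# Tracing records only the operations executed for the example input, so the
# "x * 2" path is baked in. PyTorch emits a TracerWarning about this, which is
# silenced here to keep the output clean.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(branch_on_sum, positive)

# Scripting compiles the Python source, so the if-statement is preserved.
scripted = torch.jit.script(branch_on_sum)

expected = branch_on_sum(negative)
traced_matches = torch.equal(traced(negative), expected)
scripted_matches = torch.equal(scripted(negative), expected)
print("traced matches eager:", traced_matches)
print("scripted matches eager:", scripted_matches)
```

The traced function agrees with the eager one only on inputs that take the same branch as the example input, which is consistent with the symptom described in the question.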

I solved the problem. I kept torch.jit.trace() for the model as a whole, but made my YOLOLayer inherit from ScriptModule: class YOLOLayer(torch.jit.ScriptModule):, and decorated its forward method with @torch.jit.script_method:

@torch.jit.script_method
def forward(self, x, targets=torch.tensor([]), img_dim=torch.tensor(416)):

I did the same for the helper method:

@torch.jit.script_method
def compute_grid_offsets(self, grid_size):

After that I went line by line, fixing each error caused by code that was incompatible with scripting.
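The general pattern here, tracing the overall model while scripting only the part with control flow, can be sketched as follows. This is a toy example under stated assumptions, not the actual YOLOLayer code: pick_branch and toy_model are hypothetical names, and the arithmetic stands in for the real detection logic.

```python
import torch


@torch.jit.script
def pick_branch(x: torch.Tensor) -> torch.Tensor:
    # Scripted helper: this if-statement survives even when the caller is traced.
    if x.sum() > 0:
        return x * 2
    return x - 10


def toy_model(x):
    # Straight-line code that tracing handles well, followed by a call into
    # the scripted helper.
    shifted = x - 0.5
    return pick_branch(shifted)


# The example input takes the "x * 2" branch inside pick_branch.
traced_model = torch.jit.trace(toy_model, torch.ones(3))

# A different input takes the other branch, and the traced model follows it.
result = traced_model(-torch.ones(3))
print(result)
```

Because the call into the scripted function is recorded as a call rather than being flattened into the trace, the branch is re-evaluated for every input.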

Good to hear you fixed it! We changed the TorchScript API in PyTorch 1.2 to make it easier to use (i.e. you no longer need to change your model to inherit from ScriptModule instead of nn.Module, and you don’t need @script_method). You can read more about it here. But this is just sugar over the same thing you’re already doing, so if you already have it working you don’t need to change anything.
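For anyone starting fresh on PyTorch 1.2 or later, a rough sketch of the newer style looks like this. GridLayer is a hypothetical stand-in for a layer such as YOLOLayer, and the slicing logic is only a placeholder for real shape-dependent code.

```python
import torch
import torch.nn as nn


class GridLayer(nn.Module):  # plain nn.Module, no ScriptModule base class
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Control flow that depends on the input's shape; no @script_method needed.
        if x.size(1) > 4:
            return x[:, :4]
        return x


# torch.jit.script compiles the module instance, including its forward method.
scripted_layer = torch.jit.script(GridLayer())

out_wide = scripted_layer(torch.zeros(2, 8))
out_narrow = scripted_layer(torch.zeros(2, 3))
print(out_wide.shape, out_narrow.shape)
```

The resulting object can be saved with scripted_layer.save(...) in the same way as a traced module.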