DeepAR and torch.jit.trace

I have created a simple DeepAR model (from gluonts.torch.model.deepar.DeepAREstimator). The model is based on 3 time series with real-valued targets only. I cannot get torch.jit.trace to work with this model: the example inputs I pass are always wrong.
I need to convert my model to TorchScript so I can load it in Java with DJL.
I am able to do that with simpler models.
Has anybody gotten tracing to work with DeepAR? An example would be much appreciated.
Thanks!

torch.jit.trace can be silently wrong; see the torch.jit.trace page in the PyTorch 2.1 documentation.
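To make "silently wrong" concrete, here is a minimal sketch with a toy module (SignDependent is hypothetical, not part of gluonts). Tracing records only the branch taken for the example input, so data-dependent Python control flow gets frozen into the graph. The tracer prints a TracerWarning, which is easy to miss, but the traced module still runs without error and disagrees with eager mode for other inputs.

```python
import torch
import torch.nn as nn


class SignDependent(nn.Module):
    """Toy module whose result depends on a data-dependent Python branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The branch is chosen in Python from the tensor's values,
        # so tracing records only the branch taken for the example input.
        if x.sum() > 0:
            return x * 2
        return x * -1


model = SignDependent()

# Trace with an input whose sum is positive: only the `x * 2` path is recorded.
traced = torch.jit.trace(model, (torch.ones(3),))

negative = -torch.ones(3)
eager_out = model(negative)    # eager mode takes the `x * -1` branch -> tensor([1., 1., 1.])
traced_out = traced(negative)  # traced graph replays `x * 2` -> tensor([-2., -2., -2.])
```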

I would suggest taking a look at torch.export.export() from the latest release.

I have tried export with the latest version of torch (2.1.0), and I am getting:
RuntimeError: Windows not yet supported for torch.compile.
Thanks!

Yeah, Windows support is unlikely to happen anytime soon. I even know some Windows developers who are using WSL nowadays.

I was able to solve my problem by getting the inputs from the model itself:

```python
import torch

# Load the pickled gluonts predictor. weights_only=False is needed on newer
# PyTorch releases, where torch.load defaults to weights_only=True.
predictor = torch.load('models/deepAR.pth', weights_only=False)

# The prediction network carries example inputs keyed by input name.
inputs = predictor.prediction_net.example_input_array
feat_static_cat = inputs['feat_static_cat']
feat_static_real = inputs['feat_static_real']
past_time_feat = inputs['past_time_feat']
past_target = inputs['past_target']
past_observed_values = inputs['past_observed_values']
future_time_feat = inputs['future_time_feat']

# trace() passes these positionally, so the order must match forward().
example_inputs = (feat_static_cat, feat_static_real, past_time_feat,
                  past_target, past_observed_values, future_time_feat)
# Trace the inner network module rather than its wrapper.
model = predictor.prediction_net.model
traced_script_module = torch.jit.trace(model, example_inputs)
```
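One thing the fix above relies on is that torch.jit.trace takes its example inputs as a positional tuple, so their order has to match the model's forward() signature. Below is a minimal, self-contained sketch of that idea plus the save step needed before loading in DJL. MultiInputNet, its two inputs, and the file name are hypothetical stand-ins, not the real DeepAR interface. The sketch orders a name-keyed dict by inspect.signature(), traces, saves a TorchScript file with save(), reloads it with torch.jit.load(), and keeps the outputs for comparison against eager mode.

```python
import inspect
import os
import tempfile

import torch
import torch.nn as nn


class MultiInputNet(nn.Module):
    """Hypothetical stand-in for a model whose forward takes several named tensors."""

    def forward(self, past_target: torch.Tensor,
                past_observed_values: torch.Tensor) -> torch.Tensor:
        # Mean of the observed past values per series.
        masked = past_target * past_observed_values
        return masked.sum(dim=-1) / past_observed_values.sum(dim=-1).clamp(min=1.0)


model = MultiInputNet().eval()

# Example inputs keyed by name, deliberately NOT in forward() order.
example_dict = {
    "past_observed_values": torch.ones(3, 5),
    "past_target": torch.arange(15, dtype=torch.float32).reshape(3, 5),
}

# Order the tensors by forward()'s parameter names (bound method, so no `self`).
param_names = list(inspect.signature(model.forward).parameters)
example_inputs = tuple(example_dict[name] for name in param_names)

traced = torch.jit.trace(model, example_inputs)

# Save as a TorchScript file (the format DJL's PyTorch engine loads), then reload.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "multi_input_traced.pt")
    traced.save(path)
    reloaded = torch.jit.load(path)

eager_out = model(*example_inputs)       # tensor([ 2.,  7., 12.])
reloaded_out = reloaded(*example_inputs)
```

Ordering by the signature rather than by dict insertion order means a reordered dict still lines up with forward(); a mismatch would otherwise either fail on shapes or, worse, trace with the tensors swapped.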