torch.jit.trace() fails if the model's forward() takes more than one argument

My PyTorch model's forward() function takes more than one argument: apart from the data input, it takes additional arguments used for some computation. When I execute

torch.jit.trace(loadedmodel, example_inputs=data)

I get the error: forward() missing 4 required positional arguments: 'l1', 'l2', 'criterion1', and 'criterion2'. I have already trained this model and loaded it into the variable loadedmodel, but I can't figure out how to pass loadedmodel to torch.jit.trace, which seems to expect forward() to take only the data input. Please help.
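A minimal reproduction of the error, as a sketch. MultiArgModel is a hypothetical stand-in whose forward() matches the signature in the error message; the real model's layers and argument types are not shown, so l1 and l2 are assumed to be label tensors and criterion1 and criterion2 loss modules:

```python
import torch
import torch.nn as nn


class MultiArgModel(nn.Module):
    # Hypothetical model: besides the data x, forward() needs label tensors
    # l1, l2 and loss modules criterion1, criterion2 (assumed types).
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, l1, l2, criterion1, criterion2):
        out = self.linear(x)
        return criterion1(out, l1) + criterion2(out, l2)


loadedmodel = MultiArgModel().eval()
data = torch.randn(3, 4)

error_message = None
try:
    # Only the data tensor is supplied, so tracing calls forward(data),
    # which is missing the other four positional arguments.
    torch.jit.trace(loadedmodel, example_inputs=data)
except TypeError as err:
    error_message = str(err)
    print(error_message)
```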

You just need to pass example_inputs as a tuple containing every positional argument of forward(), in the same order:

torch.jit.trace(loadedmodel, example_inputs=(x, l1, l2, criterion1, criterion2))
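One caveat: torch.jit.trace only accepts tensors, or (possibly nested) tuples, lists and dicts of tensors, as example inputs, so the tuple works as-is only if l1, l2, criterion1 and criterion2 are all tensors. If criterion1 and criterion2 are loss modules such as nn.MSELoss (their names suggest this, but the question does not say), one workaround is to wrap the model so the loss modules become fixed attributes and only tensors are passed to the tracer. A minimal sketch, where MultiArgModel and TraceWrapper are hypothetical names and the loss choices are assumptions:

```python
import torch
import torch.nn as nn


class MultiArgModel(nn.Module):
    # Hypothetical model with the question's forward() signature.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, l1, l2, criterion1, criterion2):
        out = self.linear(x)
        return criterion1(out, l1) + criterion2(out, l2)


class TraceWrapper(nn.Module):
    # Hypothetical wrapper: stores the non-tensor arguments (the loss
    # modules) as attributes so that forward() only takes tensors.
    def __init__(self, model, criterion1, criterion2):
        super().__init__()
        self.model = model
        self.criterion1 = criterion1
        self.criterion2 = criterion2

    def forward(self, x, l1, l2):
        return self.model(x, l1, l2, self.criterion1, self.criterion2)


loadedmodel = MultiArgModel().eval()
x = torch.randn(3, 4)
l1 = torch.randn(3, 2)
l2 = torch.randn(3, 2)

# Assumed loss modules; substitute whatever the real model was trained with.
wrapped = TraceWrapper(loadedmodel, nn.MSELoss(), nn.L1Loss()).eval()

# example_inputs is a tuple of tensors only, matching TraceWrapper.forward().
traced = torch.jit.trace(wrapped, example_inputs=(x, l1, l2))
print(traced(x, l1, l2))
```

The traced module is then called as traced(x, l1, l2). Because tracing records the operations executed for the example inputs, the chosen loss functions are baked into the traced graph and cannot be swapped at call time.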