How do I get started with JIT?

I have an RL Python code that I want to speed up with JIT.

I changed the class definition from torch.nn.Module to torch.jit.ScriptModule and added the @torch.jit.script_method decorator. I still need to rerun the numbers, but my impression is that it speeds up training slightly.

If I print the layers I can see: (conv2_q1): RecursiveScriptModule(original_name=Conv2d)
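For reference, the change I made looks roughly like this (a simplified sketch, not my actual model; the layer sizes here are made up):

```python
import torch
import torch.nn as nn

class QNet(torch.jit.ScriptModule):
    """Old-style TorchScript module: subclass ScriptModule and
    mark scripted methods with @torch.jit.script_method."""

    def __init__(self):
        super().__init__()
        # Arbitrary example layer; named like the one in my printout
        self.conv2_q1 = nn.Conv2d(4, 32, kernel_size=3)

    @torch.jit.script_method
    def forward(self, x):
        return torch.relu(self.conv2_q1(x))

net = QNet()
out = net(torch.randn(1, 4, 8, 8))
print(out.shape)   # torch.Size([1, 32, 6, 6])
print(net)         # submodules print as RecursiveScriptModule(original_name=Conv2d)
```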

What else can I speed up with JIT? Can I set up the training part with JIT?

Also, how does all this tie in with torch.jit.trace and torch.jit.script?

This is a beginner question; I am quite new to this kind of optimization. Feel free to point me to any training material that would help me understand everything.

Thanks!

The recommended approach is to use torch.compile, as TorchScript is in "maintenance" mode and won't receive any major updates anymore.
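A minimal sketch of what switching to torch.compile could look like (the model here is a toy example, and the suppress_errors flag is only set so the snippet degrades gracefully on machines without a working compile backend):

```python
import torch
import torch.nn as nn
import torch._dynamo

# Portability fallback: if the default "inductor" backend cannot
# compile on this machine (e.g. no C++ compiler), run eagerly instead.
torch._dynamo.config.suppress_errors = True

# Toy CNN standing in for the real model
model = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=3),  # (1, 4, 8, 8) -> (1, 32, 6, 6)
    nn.ReLU(),
    nn.Flatten(),                     # -> (1, 1152)
    nn.Linear(32 * 6 * 6, 2),         # -> (1, 2)
)

# torch.compile wraps the module; actual compilation is triggered
# lazily on the first forward call.
compiled = torch.compile(model)
out = compiled(torch.randn(1, 4, 8, 8))
print(out.shape)  # torch.Size([1, 2])
```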


Thanks for the reply.

In fact, I realized after posting that torch.compile is better: it speeds up the model much more for both training and inference.

However, I was wondering whether there are other ways to speed things up besides torch.compile(model), using torch.jit or other PyTorch methods (without using distributed training).

I have a Trainer class that contains actor and critic CNN models, as well as the training methods. Can I pass this Trainer class to torch.compile, or just the models themselves?
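For concreteness, here is a hypothetical sketch of the kind of setup I mean (the names, layers, and loss are placeholders, not my real code): torch.compile accepts nn.Modules and plain functions, so my understanding is that I would compile the models inside the Trainer rather than the Trainer object itself.

```python
import torch
import torch.nn as nn
import torch._dynamo

torch._dynamo.config.suppress_errors = True  # fall back to eager if the backend fails

class Trainer:
    """Hypothetical trainer holding actor/critic models."""

    def __init__(self):
        # Toy stand-ins for the real actor/critic CNNs
        self.actor = torch.compile(nn.Linear(8, 2))
        self.critic = torch.compile(nn.Linear(8, 1))
        self.opt = torch.optim.Adam(
            list(self.actor.parameters()) + list(self.critic.parameters())
        )

    def train_step(self, obs: torch.Tensor) -> float:
        # Placeholder loss just to exercise the compiled forward/backward
        loss = self.critic(obs).pow(2).mean()
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()

trainer = Trainer()
loss = trainer.train_step(torch.randn(16, 8))
print(loss)  # a scalar loss value
```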

Could you provide some references?