How exactly to use torch.compile()

Hi, I have recently started using PyTorch and I have run into a problem. I have a model that is already fine-tuned, and I upgraded my torch version to 2.x. Now I am looking to add torch.compile() to the codebase, but I only want to use it for inference.

The model is already trained and in production, and I just want to add the compile call at inference time to leverage PyTorch 2.x's magical superpowers. So I want to clear up a couple of doubts here.

  1. Is torch.compile meant to be used for training loops to reduce training time, or can it be used directly in inference mode?
  2. Autograd calculates gradients in the backward pass, and during inference we don't calculate gradients. So if we use torch.compile during inference, how does it work in the backend?

Thanks for your help!

  1. torch.compile can also be used for inference; it is not limited to training loops.
  2. In the inference case we won't try to trace and compile the graph used for backward, so only the forward graph is compiled.
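
To make the two points above concrete, here is a minimal sketch of compiling an already-trained model for inference only. `TinyClassifier`, `compile_for_inference`, and `predict` are hypothetical names made up for this illustration (the real production model would take the place of `TinyClassifier`), and the default backend assumes a working compiler toolchain is available on the machine.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the fine-tuned production model."""

    def __init__(self, in_features: int = 16, num_classes: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def compile_for_inference(model: nn.Module, backend: str = "inductor") -> nn.Module:
    """Put the model in eval mode and wrap it with torch.compile.

    "inductor" is the default backend; "eager" skips code generation and is
    handy for checking that the model traces without needing a compiler.
    """
    model.eval()  # disable dropout / use running batch-norm statistics
    return torch.compile(model, backend=backend)


@torch.no_grad()  # gradients are disabled, so no backward graph is needed
def predict(compiled_model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    return compiled_model(batch)


if __name__ == "__main__":
    model = TinyClassifier()  # in practice: load the fine-tuned weights here
    compiled = compile_for_inference(model)
    batch = torch.randn(8, 16)
    logits = predict(compiled, batch)  # the first call triggers compilation
    print(logits.shape, logits.requires_grad)
```

The first call is slower because that is when tracing and compilation happen; later calls with the same input shapes reuse the compiled code. Whether this gives a speedup depends on the model and hardware, so it is worth benchmarking against the uncompiled model before shipping it.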