Turning PYTORCH_JIT mode on/off dynamically

Hi,
We are planning to use PyTorch JIT in production, and our top-level modules use TorchScript annotations. We need JIT scripting to be turned off during training and switched back on for model export (and inference). However, PYTORCH_JIT is an environment variable that is set statically.

  1. How can we achieve this without statically setting PYTORCH_JIT=0 when launching the process?
  2. While exporting the model, we perform a correctness check, so we need to toggle the JIT dynamically in order to load both models (with and without JIT). How can we achieve this?

Thanks!

@apaszke told me that PYTORCH_JIT=0 is only meant for debugging. In your case, I would argue that you should call jit.script just before exporting, at which point you can also compare both versions for correctness.
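A minimal sketch of that script-then-compare step before export (the module, input shape, and tolerance below are placeholders for illustration):

```python
import torch


class Net(torch.nn.Module):
    """Hypothetical example module."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = Net().eval()
scripted = torch.jit.script(model)  # compile only at export time

# Correctness check: eager and scripted outputs should match
x = torch.randn(8, 4)
with torch.no_grad():
    assert torch.allclose(model(x), scripted(x), atol=1e-6)

scripted.save("model.pt")  # export the checked ScriptModule
```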

I assume you mean using the module normally during training and then calling torch.jit.trace at export time to create the JIT version. While this works for fully traceable models, my use case mixes tracing and TorchScript, and I don’t see any way to disable the TorchScript annotations: the JIT compiles script functions/methods even if torch.jit.trace is never called. Is there any other way to achieve this?
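To illustrate the problem: a function decorated with @torch.jit.script is compiled as soon as it is defined, with no tracing involved at all (toy function for illustration):

```python
import torch


@torch.jit.script  # compiled at definition time; no torch.jit.trace call needed
def scaled_sum(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return alpha * x.sum()


# Compilation has already happened: this is a ScriptFunction,
# not a plain Python function.
print(isinstance(scaled_sum, torch.jit.ScriptFunction))
```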


I have these methods:

import torch


def set_jit_enabled(enabled: bool):
    """ Enables/disables JIT """
    # Parse major/minor as integers: a plain string comparison breaks for
    # versions like "1.10", since "1.10" < "1.7" lexicographically.
    version = tuple(int(v) for v in torch.__version__.split(".")[:2])
    if version < (1, 7):
        torch.jit._enabled = enabled
    else:
        if enabled:
            torch.jit._state.enable()
        else:
            torch.jit._state.disable()


def jit_enabled():
    """ Returns whether JIT is enabled """
    # Same integer version comparison as above, for the same reason
    version = tuple(int(v) for v in torch.__version__.split(".")[:2])
    if version < (1, 7):
        return torch.jit._enabled
    else:
        return torch.jit._state._enabled.enabled

When training, I call set_jit_enabled(False) before instantiating modules that have JIT annotations. I haven’t found a proper way of doing this, but this hack works just fine.
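A self-contained sketch of the effect on PyTorch >= 1.7, calling the private torch.jit._state API directly (note these are internal functions that may change between releases):

```python
import torch

torch.jit._state.disable()  # same effect as set_jit_enabled(False)


@torch.jit.script
def helper(x: torch.Tensor) -> torch.Tensor:
    # With the JIT disabled, the decorator is a no-op and returns
    # the plain Python function uncompiled.
    return x * 2


print(isinstance(helper, torch.jit.ScriptFunction))  # False while disabled

torch.jit._state.enable()  # restore before scripting/export
```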