Does AMP work in TorchScript?



If you can trace the model, all executed operations (including the casts inserted by autocast) should be recorded, so AMP should work. This also means that you wouldn't need to run the traced model inside an autocast region anymore.

However, AMP support for scripted models is still a work in progress.
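A minimal sketch of the tracing route, assuming a recent PyTorch release with CPU autocast support. The toy model is hypothetical, bfloat16 on CPU is used only so the example runs without a GPU (on a GPU you would use device_type="cuda"), and cache_enabled=False is passed so autocast doesn't reuse cached weight casts while tracing. This isn't an officially tested workflow, so please verify the output dtype on your own model:

```python
import torch

# Hypothetical static model (no data-dependent branches), so tracing is safe.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval()
example = torch.randn(2, 8)

# Trace inside an autocast region. cache_enabled=False makes autocast execute
# every cast during tracing instead of returning cached lower-precision weights.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=False):
    traced = torch.jit.trace(model, example)

# Run the traced model outside any autocast region: the recorded casts still run,
# so linear is computed in bfloat16 and the output keeps that dtype.
out = traced(example)
print(out.dtype)  # torch.bfloat16
```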


Hey @ptrblck, what do you mean by scripted models above? Would we need to trace the model within an autocast context block?

By “scripted model” I meant a model that was scripted via torch.jit.script, whereas a traced model is created via torch.jit.trace.
As I said before, tracing your model might work if it doesn't have any data-dependent code paths, i.e. if your model is static and tracing it works fine, you should be able to use AMP. However, this workflow isn't properly tested, so I would consider it experimental. Scripting a model with AMP is not implemented yet.
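To illustrate the data-dependent code path caveat, here is a minimal sketch with a hypothetical module whose forward branches on the input's values (no AMP involved). Tracing records only the branch taken for the example input, while scripting compiles both branches:

```python
import torch


class SignDependent(torch.nn.Module):
    """Hypothetical module whose output depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x * 2   # branch taken for positive inputs
        return x + 10      # branch taken otherwise


model = SignDependent()
pos = torch.ones(3)
neg = -torch.ones(3)

# Scripting compiles both branches of the Python control flow.
scripted = torch.jit.script(model)
# Tracing records only the ops executed for `pos` (the x * 2 branch)
# and emits a TracerWarning about converting a tensor to a Python bool.
traced = torch.jit.trace(model, pos)

print(scripted(neg))  # tensor([9., 9., 9.])    -> correct branch
print(traced(neg))    # tensor([-2., -2., -2.]) -> silently reuses the recorded branch
```

If tracing your model emits such TracerWarnings, the traced graph (and any casts recorded under autocast) may not cover all inputs.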

Can scripted models be used with AMP now? Thanks.

AMP should now work in scripted models with the nvFuser JIT backend, once it's enabled via torch._C._jit_set_autocast_mode(True). You might need to install the latest nightly binary to get the current nvFuser updates.
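Since torch._C._jit_set_autocast_mode is a private binding, it may not exist in every build. A minimal sketch, assuming you only want to flip the toggle when the installed build exposes it (the helper name enable_jit_autocast is hypothetical, not a PyTorch API):

```python
import torch


def enable_jit_autocast() -> bool:
    """Hypothetical helper: turn on TorchScript autocast if this build supports it.

    torch._C._jit_set_autocast_mode is a private binding, so it may be missing
    (older releases raise AttributeError) or change without notice.
    """
    setter = getattr(torch._C, "_jit_set_autocast_mode", None)
    if setter is None:
        return False
    setter(True)
    return True


print(torch.__version__, enable_jit_autocast())
```

Returning False instead of raising lets a script fall back to eager-mode autocast on builds without the toggle.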


Thanks for the answer. But what is the “latest nightly binary”, and how can I install it? Thanks

You can select it on the install page via the “Preview (nightly)” option, or by building from source.
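If you're unsure whether the installed binary is a nightly, you can inspect torch.__version__. A small sketch, assuming nightly wheels keep using a PEP 440 development-release tag (e.g. 2.1.0.dev20230601+cu118) while stable releases don't; builds from source typically report an a0+git suffix instead and aren't flagged by this check. The helper name is hypothetical:

```python
import torch


def is_nightly_build(version: str = torch.__version__) -> bool:
    """Hypothetical check: nightly wheels carry a ".devYYYYMMDD" tag in their version."""
    return ".dev" in version


print(torch.__version__, is_nightly_build())
```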


Is this option only available in the nightly binaries?
In my case, running:

torch.__version__,torch.version.cuda,torch.backends.cudnn.version()
torch._C._jit_set_autocast_mode(True)

yields:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_17936/3336213180.py in <module>
      1 torch.__version__,torch.version.cuda,torch.backends.cudnn.version()
----> 2 torch._C._jit_set_autocast_mode(True)

AttributeError: module 'torch._C' has no attribute '_jit_set_autocast_mode'

I'm asking because I get an error when using torch.jit.script with autocast.
Thanks!