Mixed precision inference on ARM servers


My use case is to take an FP32 pre-trained PyTorch model, convert it to FP16 (both the weights and any computation that is amenable to FP16), and then trace the model. Later, I will load the traced model in TVM (a deep learning compiler) and use it to generate code for ARM servers. ARM servers have instructions that speed up FP16 computation. Please let me know if this is possible today.

Note that I need to use TorchScript, since TVM takes a traced model as input. Also, the goal here is to get a speedup, so my hope is that the resulting traced model has some operators (like conv2d, dense) whose input dtypes are FP16.


autocast is not supported yet in scripted models, but if your model doesn’t use any data-dependent control flow, you could try to trace the model inside the autocast context.
I don’t know how the TorchScript model would be provided to TVM or what kind of instructions ARM servers use.

I see. Thanks for the reply @ptrblck
I will read more about autocast tomorrow and will try tracing it.

Meanwhile, I also used the following script (using half), which gave me an FP16 graph:

import torchvision
import torch
import tvm
from tvm import relay

is_fp16 = True

# eval() puts batch norm and dropout into inference mode before tracing
model = torchvision.models.resnet18(pretrained=True).eval()
x = torch.rand(1, 3, 224, 224)
if is_fp16:
    # Cast both the weights and the example input to FP16 (on the GPU)
    model = model.cuda().half()
    x = x.cuda().half()
# Trace once, whether or not the FP16 cast was applied
scripted_model = torch.jit.trace(model, (x,))

# Import the traced model into TVM Relay
input_name = "input0"
input_shape = (1, 3, 224, 224)
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
mod = relay.transform.InferType()(mod)

My use case is inference. The above code converted the whole model to FP16. But I am worried that it might be doing too much in FP16 and may need to go back to FP32 for some operators like sum or batch_norm.

That’s a reasonable concern and is the reason amp.autocast was implemented: it uses FP32 precision where necessary. For your use case, you could certainly try calling .half() on the entire model and checking the results manually.
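To illustrate why a reduction like sum can need a wider accumulator, here is a minimal sketch in plain Python. It is not PyTorch code: it assumes only that the standard library's struct module can round-trip IEEE 754 half precision (format code "e"), and the helper names to_fp16 and sum_with_fp16_accumulator are made up for this example. A Python float (FP64) stands in for the wider FP32 accumulator.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def sum_with_fp16_accumulator(values):
    """Sum values while rounding the running total to FP16 after every add."""
    total = 0.0
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return total


ones = [1.0] * 4096

# FP16 has an 11-bit significand, so above 2048 it can no longer represent
# every integer and adding 1.0 stops changing the running total.
fp16_total = sum_with_fp16_accumulator(ones)

# A wider accumulator (Python's FP64 here, standing in for FP32) is exact.
wide_total = sum(ones)

print(fp16_total, wide_total)  # 2048.0 4096.0
```

Whether a particular backend really accumulates in FP16 depends on the kernel, so treat this only as an illustration of the numerical risk, not as a description of what .half() does for any given operator.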

Makes sense. So, I followed up on autocast. This is what my code looks like:

model = torchvision.models.resnet18()

x = torch.rand(1, 3, 224, 224).cuda()
model = model.cuda()
with torch.cuda.amp.autocast(enabled=True):
    scripted_model = torch.jit.trace(model, (x))

It failed with the following error: RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

I then added no_grad:

model = torchvision.models.resnet18()

x = torch.rand(1, 3, 224, 224).cuda()
model = model.cuda()
with torch.cuda.amp.autocast(enabled=True):
    with torch.no_grad():
        scripted_model = torch.jit.trace(model, (x))

But this also failed, with a different error:

torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!

I am not sure if my usage is correct.

Thanks for the update. I guess that some (new) graph passes might be running into these issues, but I’m unsure which optimizations are performed. Until scripting + amp is fully supported, you could manually cast the model and check the outputs.
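For the "check the outputs" step, a minimal sketch of a comparison helper could look like the following. It is plain Python working on flat lists of floats; the name compare_outputs and the default tolerances are assumptions chosen for illustration, not a recommended standard. The demo at the bottom simulates an FP16 result by rounding the reference values through half precision.

```python
import math
import struct


def compare_outputs(reference, candidate, rtol=1e-2, atol=1e-3):
    """Return (max_abs_err, n_mismatched) for two equal-length float sequences.

    An element counts as mismatched if it is NaN/Inf or if its error exceeds
    atol + rtol * |reference|. The tolerances are illustrative assumptions.
    """
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same number of elements")
    max_abs_err = 0.0
    mismatched = 0
    for ref, got in zip(reference, candidate):
        if math.isnan(got) or math.isinf(got):
            mismatched += 1
            continue
        err = abs(ref - got)
        max_abs_err = max(max_abs_err, err)
        if err > atol + rtol * abs(ref):
            mismatched += 1
    return max_abs_err, mismatched


# Demo: simulate an FP16 output by rounding FP32-style reference values.
reference = [0.1, 1.0, 100.0]
candidate = [struct.unpack("<e", struct.pack("<e", v))[0] for v in reference]
max_abs_err, mismatched = compare_outputs(reference, candidate)
print(max_abs_err, mismatched)
```

With real models, the two lists could come from something like out_fp32.flatten().tolist() and out_fp16.float().flatten().tolist(), where out_fp32 and out_fp16 are hypothetical names for the outputs of the original and the half-precision model on the same input. For a classifier, comparing the predicted classes as well is probably worthwhile.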

Thanks @ptrblck for the prompt responses. You have been really helpful. I will manually cast the models till then.