I am using a model that I saved as a TorchScript model (torch.jit.script), but every time I load it, it takes over a minute to warm up (the profiling/optimization done in the first 2 invocations).
Is there a way to save not just the model but also the optimized graph/bytecode (possibly for multiple input shapes)? Also, is there an option to cache the graph/bytecode as new input shapes show up?
Is the code stuck in these two forward passes for over a minute?
This sounds quite unexpected. Would you be able to share the model definition so that we could take a look at it?
Some internal utilities, such as the Jiterator, are able to cache code-generated kernels, but I don’t know if e.g. torch.compile is already doing it (I don’t think so, but I might be mistaken).
Unfortunately I can’t share my model, but I notice the same behavior with a stock model. Here is a code snippet measuring the processing time of the densenet121 model:
import torch
import torchvision
from time import time
model = torchvision.models.densenet121(weights=torchvision.models.densenet.DenseNet121_Weights.DEFAULT)
model.eval()
model.cuda()
data = torch.randn(1,3,224,224).cuda()
t = time()
smodel = torch.jit.script(model)
print("scripting: ", time()-t)
def test(model):
    with torch.no_grad():
        for _ in range(4):
            t = time()
            out = model(data)
            torch.cuda.synchronize()
            print("process: ", time()-t)
test(smodel)
torch.jit.save(smodel, "model.pt")
lmodel = torch.jit.load("model.pt")
print("loaded torchscript")
test(lmodel)
t = time()
frozen = torch.jit.freeze(smodel)
print("froze model: ", time()-t)
torch.jit.save(frozen, "fmodel.pt")
fmodel = torch.jit.load("fmodel.pt")
print("loaded frozen model")
test(fmodel)
traced = torch.jit.trace(model, [data])
torch.jit.save(traced, "tmodel.pt")
tmodel = torch.jit.load("tmodel.pt")
print("loaded traced model")
test(tmodel)
I ran this in an nvcr.io/nvidia/pytorch:22.10-py3 Docker container on a compute server (8xV100, CUDA driver 525.60.13), and it was the only task running. The results are below (without putting the model in eval mode, the time was ~65s):
scripting: 1.7816369533538818
process: 0.6788175106048584
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1185: UserWarning: operator() profile_node %201 : bool = prim::profile_ivalue(%training.22)
does not have profile information (Triggered internally at /opt/pytorch/pytorch/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
return forward_call(*input, **kwargs)
process: 56.04067301750183
process: 0.008050203323364258
process: 0.006029844284057617
loaded torchscript
process: 0.44356417655944824
process: 40.395591020584106
process: 0.0060329437255859375
process: 0.0053980350494384766
froze model: 2.6882169246673584
loaded frozen model
process: 2.517376661300659
process: 5.002000093460083
process: 0.004984378814697266
process: 0.004503011703491211
loaded traced model
process: 0.19497919082641602
process: 0.3333449363708496
process: 0.005604267120361328
process: 0.005307674407958984
The frozen model warms up significantly faster, but I would still like to eliminate the overhead.
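Until the overhead can be eliminated, one way to keep it out of the request path is to pay it once right after loading, by running a few throwaway forward passes for every input shape that is expected. Below is a minimal sketch of that idea. The helper name `warm_up` and its parameters are my own (hypothetical, not a PyTorch API), and it only moves the warm-up cost to startup rather than removing it:

```python
from time import perf_counter


def warm_up(model, example_inputs, runs=3, sync=None):
    """Call `model` `runs` times on each example input and time every call.

    `model` is any callable (for example a loaded ScriptModule). `sync` is an
    optional zero-argument callable such as torch.cuda.synchronize, so that
    the timings include queued GPU work. Returns one list of timings (in
    seconds) per example input.
    """
    timings = []
    for example in example_inputs:
        per_input = []
        for _ in range(runs):
            start = perf_counter()
            model(example)          # result is discarded, only the side effect matters
            if sync is not None:
                sync()              # wait for asynchronous work before stopping the clock
            per_input.append(perf_counter() - start)
        timings.append(per_input)
    return timings
```

With the snippet above this could be called as `warm_up(fmodel, [data], sync=torch.cuda.synchronize)` inside `torch.no_grad()`. The default of three runs is chosen because the logs show the slow passes are the first two.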
Tracing is almost perfect, but I guess that way one has to handle traces for different inputs. Also, when the model is itself a ScriptModule, tracing does nothing, according to the warning:
torch.jit.trace(smodel, [data])
d:\conda\envs\dev\lib\site-packages\torch\jit\_trace.py:744: UserWarning: The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is.
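To illustrate what I mean by handling traces for different inputs: something along these lines would keep one traced module per input shape and build it the first time that shape shows up. This is only a sketch. `PerShapeCache` and the `build` factory are hypothetical names of mine, the cache lives in memory only (so it does not survive a reload), and whether a separate trace per shape is really needed depends on the model:

```python
class PerShapeCache:
    """Keep one specialised callable per input shape, built on first use.

    `build` is a factory supplied by the caller (an assumption of this
    sketch), for example: lambda x: torch.jit.trace(model, [x])
    """

    def __init__(self, build):
        self._build = build
        self._by_shape = {}

    def __call__(self, x):
        key = tuple(x.shape)            # the shape is the cache key
        fn = self._by_shape.get(key)
        if fn is None:
            fn = self._build(x)         # first time this shape is seen
            self._by_shape[key] = fn
        return fn(x)

    def shapes(self):
        """Return the shapes that have been specialised so far."""
        return list(self._by_shape)
```

Usage would be `cached = PerShapeCache(lambda x: torch.jit.trace(model, [x]))` with the eager `model` from the snippet, followed by `cached(data)`. Each new shape still pays the tracing and warm-up cost once per process, which is the part I would like to avoid.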
I was hoping for something like exporting the model together with all its optimized graphs, or even saving/restoring graphs, e.g. those extracted with lmodel.graph_for(data).
Thanks