How to export/cache the graph/bytecode for TorchScript models?

I am using a model that I saved as a TorchScript model (torch.jit.script), but every time I load it, it takes over a minute to warm up (the profiling/optimization is done in the first 2 invocations).
Is there a way to save not just the model but also the optimized graph/bytecode (possibly for multiple input shapes)? Also, is there an option to cache the graph/bytecode as new input shapes show up?

Is the code stuck in these two forward passes for over a minute?
This sounds quite unexpected. Would you be able to share the model definition so that we could take a look at it?
Some internal utilities, such as the Jiterator, are able to cache code-generated kernels, but I don’t know if e.g. torch.compile already does this (I don’t think so, but I might be mistaken).

Unfortunately, I can’t share my model, but I notice the same behavior with a stock model. Here is a code snippet that measures the processing time of the densenet121 model:

import torch
import torchvision
from time import time

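model = torchvision.models.densenet121(weights=torchvision.models.densenet.DenseNet121_Weights.DEFAULT)
# Assumed: eval mode and GPU placement, which the CUDA input and torch.jit.freeze below require
model = model.eval().cuda()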

data = torch.randn(1,3,224,224).cuda()

t = time()
smodel = torch.jit.script(model)
print("scripting: ", time()-t)

def test(model):
	with torch.no_grad():
		for _ in range(4):
			t = time()
			out = model(data)
			print("process: ", time()-t)
test(smodel)
torch.jit.save(smodel, "")  # file path omitted in the post
lmodel = torch.jit.load("")

print("loaded torchscript")
test(lmodel)

t = time()
frozen = torch.jit.freeze(smodel)
print("froze model: ", time()-t)
torch.jit.save(frozen, "")  # file path omitted in the post
fmodel = torch.jit.load("")

print("loaded frozen model")
test(fmodel)

traced = torch.jit.trace(model, [data])
torch.jit.save(traced, "")  # file path omitted in the post
tmodel = torch.jit.load("")

print("loaded traced model")
test(tmodel)

I ran this in a Docker container on a compute server (8xV100, CUDA driver 525.60.13), and it was the only task running. The results are below (without putting the model in eval mode, the time was ~65s):

scripting:  1.7816369533538818
process:  0.6788175106048584
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/ UserWarning: operator() profile_node %201 : bool = prim::profile_ivalue(%training.22)
 does not have profile information (Triggered internally at /opt/pytorch/pytorch/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
  return forward_call(*input, **kwargs)
process:  56.04067301750183
process:  0.008050203323364258
process:  0.006029844284057617
loaded torchscript
process:  0.44356417655944824
process:  40.395591020584106
process:  0.0060329437255859375
process:  0.0053980350494384766
froze model:  2.6882169246673584
loaded frozen model
process:  2.517376661300659
process:  5.002000093460083
process:  0.004984378814697266
process:  0.004503011703491211
loaded traced model
process:  0.19497919082641602
process:  0.3333449363708496
process:  0.005604267120361328
process:  0.005307674407958984
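A side note on the measurements above: CUDA kernels are launched asynchronously, so timing a forward pass with wall-clock calls alone may attribute GPU work to a later iteration or under-report it. Below is a minimal sketch of a timing helper that synchronizes the device around the call. The name `timed_call` is hypothetical, and a tiny `Linear` module stands in for densenet121 so the snippet runs without torchvision or a GPU.

```python
import time

import torch


def timed_call(fn, *args):
    """Run fn(*args) and return (result, elapsed_seconds).

    Synchronizes the GPU before and after the call when CUDA is available,
    so asynchronously launched kernels are included in the measurement.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        result = fn(*args)
    if use_cuda:
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


device = "cuda" if torch.cuda.is_available() else "cpu"
# Tiny stand-in module (assumption); the thread uses densenet121 instead.
model = torch.jit.script(torch.nn.Linear(8, 4).eval().to(device))
x = torch.randn(2, 8, device=device)

for i in range(4):
    _, seconds = timed_call(model, x)
    print(f"call {i}: {seconds:.6f}s")
```

The first calls should still be the slow ones, since that is where the JIT profiles and optimizes; the synchronization only makes the per-call numbers easier to trust.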

The frozen model warms up significantly faster, but I would still like to eliminate the overhead.
Tracing is almost perfect, but I guess that way one has to handle traces for different inputs. Also, when the model is already a ScriptModule, this does not work, according to the warning:

torch.jit.trace(smodel, [data])
d:\conda\envs\dev\lib\site-packages\torch\jit\ UserWarning: The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is.
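One way to handle traces for different inputs is to keep one trace per input shape and create it lazily the first time that shape shows up. The following is only a sketch under assumptions: `ShapeKeyedTraceCache` is a hypothetical helper, not a PyTorch API, it wraps the eager module (not a ScriptModule, which `torch.jit.trace` returns unchanged), and a small convolutional network stands in for the real model.

```python
import torch


class ShapeKeyedTraceCache:
    """Traces an eager module once per distinct input signature and reuses it."""

    def __init__(self, module):
        self.module = module.eval()
        self.traces = {}  # (shape, dtype, device) -> traced module

    def __call__(self, x):
        key = (tuple(x.shape), x.dtype, str(x.device))
        with torch.no_grad():
            if key not in self.traces:
                # First time this signature is seen: record a new trace.
                self.traces[key] = torch.jit.trace(self.module, x)
            return self.traces[key](x)


# Small stand-in network (assumption); replace with the real eager model.
net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
)
cache = ShapeKeyedTraceCache(net)

out_a = cache(torch.randn(1, 3, 16, 16))   # traces for shape (1, 3, 16, 16)
out_b = cache(torch.randn(1, 3, 16, 16))   # reuses the existing trace
out_c = cache(torch.randn(2, 3, 32, 32))   # traces again for the new shape
print(len(cache.traces))
```

Keying on the full shape is a conservative choice. A single trace often generalizes to other shapes when the model has no shape-dependent Python control flow, in which case one trace may be enough.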

I was hoping for something like exporting the model together with all of its optimized graphs, or even saving/restoring the graphs, e.g. those extracted with lmodel.graph_for(data).
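Since the timings above show that a loaded model goes through the warm-up again, one possible workaround is to pay that cost once at startup for every input shape that is expected later. This is a sketch, not a way to persist the optimized graphs: `warm_up` is a hypothetical helper, the list of shapes is an assumption, and the model is saved to an in-memory buffer only to keep the example self-contained.

```python
import io

import torch


def warm_up(module, example_shapes, device, passes=2):
    """Run `passes` forward calls per shape so the JIT can profile and optimize."""
    with torch.no_grad():
        for shape in example_shapes:
            dummy = torch.randn(*shape, device=device)
            for _ in range(passes):
                module(dummy)
    if device.type == "cuda":
        torch.cuda.synchronize()  # make sure the warm-up work has finished
    return module


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Small stand-in network (assumption); the thread uses densenet121.
net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval().to(device)

# Save and reload through a buffer to mimic the save/load round trip.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(net), buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer, map_location=device)

# Shapes expected at inference time (assumption).
warm_up(loaded, [(1, 3, 32, 32), (4, 3, 64, 64)], device)
out = loaded(torch.randn(1, 3, 32, 32, device=device))
print(tuple(out.shape))
```

The default of two passes per shape follows the observation in the question that the profiling and optimization happen in the first 2 invocations.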