I was trying PyTorch 2.0 and found that torch.jit.trace could not convert a model wrapped with torch.compile. For example:
import torch
from torchvision.models import resnet50
model = resnet50(weights=None)
# torch.compile returns an optimized wrapper around the original module
compile_model = torch.compile(model)
example_forward_input = torch.rand(1, 3, 224, 224)
# Tracing the compiled wrapper is the step that fails
c_model_traced = torch.jit.trace(compile_model, example_forward_input)
torch.jit.save(c_model_traced, "c_trace_model.pt")
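One workaround, offered as a sketch rather than an official recipe, is to trace the plain eager module instead of the torch.compile wrapper, and keep torch.compile for in-process speedups only. The small TinyNet model below is a hypothetical stand-in for resnet50 so the example needs only torch. The traced model is saved to an in-memory buffer, then reloaded and compared against the eager output. If only the compiled wrapper is at hand, PyTorch 2.x appears to keep the original module in the private attribute _orig_mod; treat that as an implementation detail that may change.

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in for resnet50; the tracing steps are the same.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


model = TinyNet().eval()
example_input = torch.rand(1, 3, 32, 32)

# Trace the uncompiled eager module, not the torch.compile wrapper.
with torch.no_grad():
    traced = torch.jit.trace(model, example_input)

# Save to an in-memory buffer (a file path such as "trace_model.pt" also works).
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The reloaded TorchScript model should match the eager model's output.
with torch.no_grad():
    out_eager = model(example_input)
    out_loaded = loaded(example_input)
print(torch.allclose(out_eager, out_loaded, atol=1e-6))
```

The saved TorchScript artifact does not carry any torch.compile optimizations; if you need those at runtime, apply torch.compile to the eager model in the process that runs it.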