Hello! Apologies if this is a silly question; I'm very new to PyTorch. I'm trying to save a traced model with `torch.jit.save`, load it later with `torch.jit.load`, and then access the size of the input tensor. However, the size information seems to get lost between saving and loading. If I run this snippet:
```python
import torch
import torchvision

model_name = "resnet18"
# Load a pretrained ResNet-18 and switch it to inference mode.
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# Trace the model with a dummy input of the expected shape.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
traced_model = torch.jit.trace(model, input_data)
print(list(traced_model.graph.inputs()))

# Save the traced model, load it back, and inspect the graph inputs again.
torch.jit.save(traced_model, "resnet18.pth")
loaded_model = torch.jit.load("resnet18.pth")
print(list(loaded_model.graph.inputs()))
```
the first `print` (the model before saving) outputs:
```
[self.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu) = prim::Param()
), input.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu) = prim::Param()
)]
```
which shows the familiar input sizes, but the second `print` (the loaded model) outputs:
```
[self.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Tensor = prim::Param()
), input.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Tensor = prim::Param()
)]
```
The input sizes have gone missing… Does anyone know why this happens? Alternatively, is there another way to save a model to disk and load it later so that the information about its inputs is still accessible? I'm asking because I'm working on an ahead-of-time compiler that relies on reading the input shape from the model.
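One possible workaround, sketched here under assumptions rather than as an official recipe: read the concrete sizes off the traced graph while they are still present, and store them inside the saved archive through the `_extra_files` argument that `torch.jit.save` and `torch.jit.load` accept in recent PyTorch releases. `TinyNet` is a hypothetical stand-in for resnet18 so no weights need downloading, and the helper names and the file name `input_shapes.json` are arbitrary. `TensorType.sizes()` lives in the private `torch._C` namespace, so it may change between versions.

```python
import json
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet18 (no pretrained weights to download)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


def graph_input_sizes(module):
    """Concrete sizes recorded on each input of module.graph (None where unknown)."""
    sizes = []
    # The first graph input is `self` (the module itself), so skip it.
    for value in list(module.graph.inputs())[1:]:
        ty = value.type()
        # Internal API: sizes() returns None when the sizes are not concrete.
        sizes.append(ty.sizes() if isinstance(ty, torch._C.TensorType) else None)
    return sizes


def save_with_shapes(traced, path):
    """Hypothetical helper: save the module plus its input sizes as an extra JSON file."""
    shapes = graph_input_sizes(traced)
    torch.jit.save(traced, path, _extra_files={"input_shapes.json": json.dumps(shapes)})


def load_with_shapes(path):
    """Hypothetical helper: load the module and the input sizes stored next to it."""
    # The keys name the extra files to read; torch.jit.load fills in their contents.
    extra_files = {"input_shapes.json": ""}
    module = torch.jit.load(path, _extra_files=extra_files)
    # The content may come back as str or bytes; json.loads accepts both.
    return module, json.loads(extra_files["input_shapes.json"])


if __name__ == "__main__":
    model = TinyNet().eval()
    example = torch.randn(1, 3, 8, 8)
    traced = torch.jit.trace(model, example)

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "tiny.pt")
        save_with_shapes(traced, path)
        loaded, shapes = load_with_shapes(path)

    print("traced graph sizes:", graph_input_sizes(traced))
    print("loaded graph sizes:", graph_input_sizes(loaded))
    print("stored sizes:", shapes)
```

The idea is that the shapes travel inside the same archive as the model, so an ahead-of-time compiler could read `input_shapes.json` instead of relying on the type annotations of the loaded graph.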