Input size disappears between torch.jit.save and torch.jit.load

Hello! Apologies if this is a silly question; I’m very new to PyTorch. I’m trying to save a traced model with torch.jit.save, load it later with torch.jit.load, and then read the size of the input tensor from the loaded model. However, the size information seems to get lost between saving and loading. If I run the snippet

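```python
import torch
import torchvision

model_name = "resnet18"

# Load a pretrained ResNet-18 and switch it to inference mode
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# Trace the model with a dummy input of the expected shape
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
traced_model = torch.jit.trace(model, input_data)

# Graph inputs of the traced model, before saving
print(list(traced_model.graph.inputs()))
torch.jit.save(traced_model, "resnet18.pth")

# Graph inputs of the same model after a save/load round trip
loaded_model = torch.jit.load("resnet18.pth")
print(list(loaded_model.graph.inputs()))
```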

the first print (corresponding to the model before saving) prints out

```
[self.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu) = prim::Param()
), input.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu) = prim::Param()
)]
```

which shows the input’s sizes and strides (each dimension is printed as size:stride). The second print (corresponding to the loaded model), however, prints

```
[self.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Tensor = prim::Param()
), input.1 defined in (%self.1 : __torch__.torchvision.models.resnet.ResNet, %input.1 : Tensor = prim::Param()
)]
```
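For reference, the numbers in the first printout are the traced input’s sizes and strides. In the PyTorch version used here, each dimension of the Float(...) annotation is printed as size:stride (newer releases may print the sizes followed by a separate strides=[...] list instead). The stdlib-only sketch below uses a hypothetical helper, parse_sizes_strides, to pull both lists out of that string and check that the strides are exactly those of a contiguous 1x3x224x224 tensor:

```python
import re
from math import prod


def parse_sizes_strides(annotation: str):
    """Hypothetical helper: extract (sizes, strides) from a 'Float(1:150528, ...)' type string."""
    # Each dimension appears as "<size>:<stride>"; keyword fields like requires_grad=0 are skipped
    pairs = re.findall(r"(\d+):(\d+)", annotation)
    sizes = [int(size) for size, _ in pairs]
    strides = [int(stride) for _, stride in pairs]
    return sizes, strides


def contiguous_strides(sizes):
    # Row-major layout: the stride of dim i is the product of all sizes after it
    return [prod(sizes[i + 1:]) for i in range(len(sizes))]


annotation = "Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)"
sizes, strides = parse_sizes_strides(annotation)
print(sizes)    # [1, 3, 224, 224]
print(strides)  # [150528, 50176, 224, 1]
print(strides == contiguous_strides(sizes))  # True
```

For example, the leading stride 150528 is simply 3 * 224 * 224, the number of elements in one image of the batch.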

The input type has been reduced to a plain Tensor, and the shape information is gone. Does anyone know why this happens? Alternatively, is there another way to save a model to disk and load it later so that the input information is still accessible? I’m asking because I’m working on an ahead-of-time compiler that relies on reading the input shape from the model.
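A likely explanation is that torch.jit.save writes the module out as TorchScript code, and torch.jit.load recompiles the graph from that code. The concrete shapes recorded during tracing are therefore not carried over, and the inputs come back as plain Tensor. One possible workaround is sketched below; it is not confirmed anywhere in this thread. The idea is to store the shape yourself inside the archive through the _extra_files argument that both torch.jit.save and torch.jit.load accept (note the leading underscore, which marks it as a semi-private parameter). The small conv model, the file name input_shape.json, and the JSON encoding are all assumptions chosen for illustration:

```python
import json
import os
import tempfile

import torch

# Tiny stand-in for resnet18 so the sketch runs without downloading weights (assumption)
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
input_shape = [1, 3, 32, 32]
traced = torch.jit.trace(model, torch.randn(input_shape))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    # Bundle the input shape into the TorchScript archive as an extra file
    torch.jit.save(traced, path,
                   _extra_files={"input_shape.json": json.dumps(input_shape).encode()})

    # Name the extra files to read back; torch.jit.load replaces the values with their contents
    extra_files = {"input_shape.json": ""}
    loaded = torch.jit.load(path, _extra_files=extra_files)

# json.loads accepts both str and bytes, whichever this PyTorch version returns
restored_shape = json.loads(extra_files["input_shape.json"])
print(restored_shape)  # [1, 3, 32, 32]
```

The drawback is that the shape only survives if whoever saves the model also writes the extra file. An ahead-of-time compiler would then read input_shape.json rather than the graph’s input types.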


Did you find a solution? I’m struggling to get the input shape from the loaded model as well.