Hello!
I am trying to save a model as a .pt file for inference using torch.jit.trace, but I keep getting the same error every time. I want to save it so that I can load it in LibTorch.
I am working on part of the Monodepth project, where I am trying to predict disparity from a single image. I am using the mono_640x192 pretrained model.
Here is the part of the code where I try to save the model:
print(" Loading pretrained encoder")
encoder = networks.ResnetEncoder(18, False)
loaded_dict_enc = torch.load(encoder_path, map_location=device)

# extract the height and width of image that this model was trained with
feed_height = loaded_dict_enc['height']
feed_width = loaded_dict_enc['width']
# The checkpoint also carries non-parameter entries (e.g. 'height', 'width'),
# so keep only keys the encoder actually has. Build the key set once instead
# of calling state_dict() for every checkpoint entry.
encoder_keys = set(encoder.state_dict().keys())
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder_keys}
encoder.load_state_dict(filtered_dict_enc)
encoder.to(device)
encoder.eval()

print(" Loading pretrained decoder")
depth_decoder = networks.DepthDecoder(
    num_ch_enc=encoder.num_ch_enc, scales=range(4))
loaded_dict = torch.load(depth_decoder_path, map_location=device)
depth_decoder.load_state_dict(loaded_dict)
depth_decoder.to(device)
depth_decoder.eval()

# FINDING INPUT IMAGES
if os.path.isfile(args.image_path):
    # Only testing on a single image
    paths = [args.image_path]
    output_directory = os.path.dirname(args.image_path)
elif os.path.isdir(args.image_path):
    # Searching folder for images
    paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext)))
    output_directory = args.image_path
else:
    raise Exception("Can not find args.image_path: {}".format(args.image_path))

print("-> Predicting on {:d} test images".format(len(paths)))

# Build the wrapper once, not once per image.
monodepth_model = MonodepthWrapper(encoder, depth_decoder)

# Tracing only needs one example input of the right shape. Every image is
# resized to (feed_width, feed_height), and the visible forward code has no
# data-dependent branching, so the first usable image is enough. The original
# re-traced the model for every image and overwrote the same file each time.
example_input = None
with torch.no_grad():
    for image_path in paths:
        if image_path.endswith("_disp.jpg"):
            # don't try to predict disparity for a disparity image!
            continue
        # Load image and preprocess; the context manager closes the file handle.
        with pil.open(image_path) as raw_image:
            input_image = raw_image.convert('RGB')
        input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
        example_input = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
        break

    if example_input is None:
        raise Exception("No usable input image found in: {}".format(args.image_path))

    # Save traced model
    save_traced_model(monodepth_model, example_input, "traced_monodepth_model.pt")
Here is the wrapper:
class MonodepthWrapper(torch.nn.Module):
    """Chain the Monodepth encoder and depth decoder into one traceable module.

    Args:
        encoder: module called on the input batch; its result is passed
            straight to ``depth_decoder``.
        depth_decoder: module called on the encoder's result; its result is
            returned by ``forward`` (or indexed, see ``output_keys``).
        output_keys: optional sequence of keys into the decoder's output
            mapping. When given, ``forward`` returns a tuple of the selected
            values in this order instead of the whole mapping. This is useful
            for ``torch.jit.trace``, which cannot return a dict keyed by
            tuples such as ``("disp", 0)``; e.g. ``output_keys=[("disp", 0)]``.
            Defaults to ``None`` (return the decoder output unchanged).
    """

    def __init__(self, encoder, depth_decoder, output_keys=None):
        # Must be ``__init__``, not ``init``. Otherwise construction falls
        # through to nn.Module.__init__, which rejects these arguments, and
        # the submodules are never registered.
        super(MonodepthWrapper, self).__init__()
        self.encoder = encoder
        self.depth_decoder = depth_decoder
        self.output_keys = None if output_keys is None else tuple(output_keys)

    def forward(self, x):
        """Run encoder then decoder on ``x``; see the class docstring for the return value."""
        features = self.encoder(x)
        outputs = self.depth_decoder(features)
        if self.output_keys is None:
            return outputs
        # Flatten to a tuple of tensors, a structure the tracer can export.
        return tuple(outputs[key] for key in self.output_keys)
Here is the save function:
def save_traced_model(model, input_tensor, output_file):
    """Trace ``model`` on ``input_tensor`` and write the TorchScript archive.

    Args:
        model: the module (or callable) to trace.
        input_tensor: example input used to record the executed operations.
        output_file: destination path for the serialized TorchScript module.
    """
    scripted_module = torch.jit.trace(model, input_tensor)
    torch.jit.save(scripted_module, output_file)
Here is the encoder's forward function:
def forward(self, input_image):
    """Run the ResNet trunk and collect intermediate feature maps.

    Args:
        input_image: image batch, normalized here as (x - 0.45) / 0.225
            (presumably expects values in [0, 1], e.g. from ToTensor;
            TODO confirm).

    Returns:
        list of five tensors: the ReLU output after conv1/bn1, then the
        outputs of layer1 (applied after maxpool) through layer4, in order.
        The same list object is also stored on ``self.features``.
    """
    # The original line ``self.features =`` was missing its ``[]``. Build a
    # local list and keep exposing it as ``self.features`` as before.
    features = []
    self.features = features
    x = (input_image - 0.45) / 0.225
    x = self.encoder.conv1(x)
    x = self.encoder.bn1(x)
    features.append(self.encoder.relu(x))
    features.append(self.encoder.layer1(self.encoder.maxpool(features[-1])))
    features.append(self.encoder.layer2(features[-1]))
    features.append(self.encoder.layer3(features[-1]))
    features.append(self.encoder.layer4(features[-1]))
    return features
And here is the decoder's forward function:
def forward(self, input_features, *args, **kwargs):
    """Decode encoder feature maps into sigmoid disparity maps.

    Args:
        input_features: sequence of encoder feature maps. The last entry is
            decoded first; when ``self.use_skips`` is set, entries 0..3 are
            concatenated in as skip connections.
        *args, **kwargs: accepted for call compatibility and ignored.

    Returns:
        dict mapping ``("disp", i)`` to a sigmoid-activated tensor for every
        ``i`` in ``self.scales``. The same dict is stored on ``self.outputs``.

    NOTE(review): TorchScript dicts only support str/int/float/bool/Tensor-like
    keys, so this tuple-keyed dict most likely cannot be returned directly
    from a traced model. Flatten it (e.g. to a tuple of tensors) in a wrapper
    before calling torch.jit.trace.
    """
    outputs = {}
    self.outputs = outputs
    # decoder
    x = input_features[-1]
    for i in range(4, -1, -1):
        x = self.convs[("upconv", i, 0)](x)
        x = upsample(x)
        if self.use_skips and i > 0:
            # Skip connection: stack channels with the matching encoder map.
            x = torch.cat((x, input_features[i - 1]), dim=1)
        # The former ``torch.cat((x,), 1)`` concatenated a single tensor with
        # itself, i.e. a no-op copy, and has been removed.
        x = self.convs[("upconv", i, 1)](x)
        if i in self.scales:
            outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
    return outputs
This is the error that I get in the console:
What could cause the problem?