Cannot save model

Hello!
I am trying to save a model as a .pt file for inference using torch.jit.trace, but I keep getting the same error over and over again. I want to save it so I can use it in LibTorch.
I am working on part of a Monodepth project, where I try to predict disparity from a single image. I am using the mono_640x192 pretrained model.
Here is the part of the code where I try to save the model:
print(" Loading pretrained encoder")
encoder = networks.ResnetEncoder(18, False)
loaded_dict_enc = torch.load(encoder_path, map_location=device)

# extract the height and width of image that this model was trained with
feed_height = loaded_dict_enc['height']
feed_width = loaded_dict_enc['width']
# The checkpoint holds extra entries (e.g. 'height'/'width') that are not
# encoder parameters; keep only the keys the encoder owns. The key set is
# built once instead of rebuilding state_dict() for every checkpoint entry.
encoder_keys = encoder.state_dict().keys()
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder_keys}
encoder.load_state_dict(filtered_dict_enc)
encoder.to(device)
encoder.eval()

print("   Loading pretrained decoder")
depth_decoder = networks.DepthDecoder(
    num_ch_enc=encoder.num_ch_enc, scales=range(4))

loaded_dict = torch.load(depth_decoder_path, map_location=device)
depth_decoder.load_state_dict(loaded_dict)

depth_decoder.to(device)
depth_decoder.eval()

# FINDING INPUT IMAGES
if os.path.isfile(args.image_path):
    # Only testing on a single image
    paths = [args.image_path]
    output_directory = os.path.dirname(args.image_path)
elif os.path.isdir(args.image_path):
    # Searching folder for images
    paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext)))
    output_directory = args.image_path
else:
    raise FileNotFoundError("Can not find args.image_path: {}".format(args.image_path))

print("-> Predicting on {:d} test images".format(len(paths)))

# Build the encoder+decoder wrapper once; it is the module that gets traced.
monodepth_model = MonodepthWrapper(encoder, depth_decoder)
traced_model_path = "traced_monodepth_model.pt"

# TRACE ONCE ON THE FIRST USABLE IMAGE
with torch.no_grad():
    for image_path in paths:

        if image_path.endswith("_disp.jpg"):
            # don't try to predict disparity for a disparity image!
            continue

        # Load image and resize it to the resolution the model was trained with
        with pil.open(image_path) as raw_image:
            input_image = raw_image.convert('RGB')
        input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
        input_image = transforms.ToTensor()(input_image).unsqueeze(0)
        input_image = input_image.to(device)

        # Every image is resized to the same feed size, so a single example
        # input is enough. Tracing per image only re-did the same work and
        # overwrote the same output file each time.
        save_traced_model(monodepth_model, input_image, traced_model_path)
        break
    else:
        # Loop finished without a break: no image was usable for tracing.
        print("-> No usable input image found; traced model was not saved")

Here is the wrapper:
class MonodepthWrapper(torch.nn.Module):
    """Chain an encoder and a depth decoder into a single traceable module.

    ``forward(x)`` returns ``depth_decoder(encoder(x))`` exactly as the decoder
    produced it, so the wrapper can be handed to ``torch.jit.trace`` as one
    model.
    """

    def __init__(self, encoder, depth_decoder):
        """Store the two sub-networks.

        Args:
            encoder: module applied first to the input.
            depth_decoder: module applied to the encoder's output.
        """
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules, so their parameters are part of the traced/saved model.
        self.encoder = encoder
        self.depth_decoder = depth_decoder

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder.

        Returns:
            Whatever ``depth_decoder`` returns for the encoder features.
        """
        features = self.encoder(x)
        outputs = self.depth_decoder(features)
        return outputs

Here is the save function:
def save_traced_model(model, input_tensor, output_file, *, strict=True):
    """Trace ``model`` on an example input and save it as a TorchScript file.

    Args:
        model: module (or callable) to trace.
        input_tensor: example input passed to ``model`` while tracing.
        output_file: path the serialized ``.pt`` file is written to.
        strict: forwarded to ``torch.jit.trace``. Pass ``False`` when the
            model returns a mutable container such as a dict, which strict
            tracing rejects.

    Returns:
        The traced module, so callers can use it without reloading the file.
    """
    traced_model = torch.jit.trace(model, input_tensor, strict=strict)
    torch.jit.save(traced_model, output_file)
    return traced_model

Here is the encoder's forward function:
def forward(self, input_image):
    """Run the encoder trunk and collect its intermediate feature maps.

    The input is normalized as ``(input_image - 0.45) / 0.225`` before
    ``self.encoder.conv1``.

    Args:
        input_image: image tensor for ``self.encoder.conv1`` (presumably an
            ``(N, 3, H, W)`` batch with values in ``[0, 1]`` -- TODO confirm).

    Returns:
        A list of five tensors: the ReLU output after ``conv1``/``bn1``,
        followed by the outputs of ``layer1`` through ``layer4``. The same list
        is also stored on ``self.features``.
    """
    features = []
    x = (input_image - 0.45) / 0.225
    x = self.encoder.conv1(x)
    x = self.encoder.bn1(x)
    features.append(self.encoder.relu(x))
    # maxpool is applied only on the way into layer1.
    features.append(self.encoder.layer1(self.encoder.maxpool(features[-1])))
    features.append(self.encoder.layer2(features[-1]))
    features.append(self.encoder.layer3(features[-1]))
    features.append(self.encoder.layer4(features[-1]))

    # Still exposed as an attribute in case callers read self.features.
    self.features = features
    return self.features

And here is the decoder's forward function:
def forward(self, input_features, *args, **kwargs):
    """Decode encoder feature maps into disparity maps at several stages.

    Args:
        input_features: sequence of encoder feature maps. The last entry seeds
            the decoder, and entry ``i - 1`` is used as the skip connection at
            stage ``i`` when ``self.use_skips`` is set (presumably the list
            returned by the encoder's ``forward`` -- TODO confirm).
        *args, **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        A dict mapping ``("disp", i)`` to a sigmoid-activated output for every
        stage ``i`` (4 down to 0) that is contained in ``self.scales``. The
        same dict is also stored on ``self.outputs``.
    """
    self.outputs = {}

    # decoder
    x = input_features[-1]
    for i in range(4, -1, -1):
        x = self.convs[("upconv", i, 0)](x)
        x = upsample(x)
        if self.use_skips and i > 0:
            # Concatenate the matching encoder feature map along dim 1.
            x = torch.cat((x, input_features[i - 1]), dim=1)
        x = self.convs[("upconv", i, 1)](x)
        if i in self.scales:
            self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
    return self.outputs

This is the error that I get in the console:


What could cause the problem?

I guess this line of code fails:

# Suspected (not confirmed) trace failure point: the decoder's output dict is keyed by tuples like ("disp", i).
self.outputs[("disp", i)] = self.sigmoid(...)

as tuple dictionary keys might not be supported.
Did you try using a single str as the dict key instead?
Also note that TorchScript is in maintenance mode and the general recommendation is to use torch.compile.

Thank you! I hadn't tried this solution.
I switched to a single str as the key, but then got another RuntimeError:
RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module’s inputs. Consider using a constant container instead (e.g. for list, use a tuple instead. for dict, use a NamedTuple instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
So I passed strict=False here:

def save_traced_model(model, input_tensor, output_file, *, strict=False):
    """Trace ``model`` on an example input and save it as a TorchScript file.

    Args:
        model: module (or callable) to trace.
        input_tensor: example input passed to ``model`` while tracing.
        output_file: path the serialized ``.pt`` file is written to.
        strict: forwarded to ``torch.jit.trace``. The default ``False`` lets
            the model return a dict. That is only safe if the dict's structure
            does not depend on the input, as the tracer's error message warns.

    Returns:
        The traced module, so callers can use it without reloading the file.
    """
    traced_model = torch.jit.trace(model, input_tensor, strict=strict)
    torch.jit.save(traced_model, output_file)
    return traced_model

and now it works.