Converting a PyTorch .pt model file into a TorchScript .ts file

I have a weights file named model.pt for brain segmentation from head CT scans.
How can I convert it into a TorchScript file so that I can use the model for deployment?

Network definition:

3D UNet
in channels: 1 (image)
out channels: 2 (brain label and background)

Input definition:

"image": {
    "type": "image",
    "format": "hounsfield",
    "modality": "CT",
    "num_channels": 1,
    "spatial_shape": [96, 96, 96],
    "dtype": "float32",
    "value_range": [0, 1],
    "is_patch_data": true,
    "channel_def": {
        "0": "image"
    }
}

Train/val split: 13 images for training and 3 for validation

Output definition:

"pred": {
    "type": "image",
    "format": "segmentation",
    "num_channels": 2,
    "spatial_shape": [96, 96, 96],
    "dtype": "float32",
    "value_range": [0, 1],
    "is_patch_data": true,
    "channel_def": {
        "0": "background",
        "1": "brain"
    }
}

Now, how can I use tracing or scripting to convert it into TorchScript?
Is this information enough?

I tried

import torch

model = torch.load('model/model.pt')

example = torch.rand(13, 96, 96, 96)

traced_script_module = torch.jit.script(model, (example))
torch.save(traced_script_module, "model/traced_resnet_model.ts")

I only used the model's input size. I also tried torch.jit.trace, but both failed.
Any help would be much appreciated.

Could you provide more information on the model.pt file? Could you also provide us with the error? It is hard to tell what could be wrong if we do not see the error message.

Anyway, my guess is that your model.pt file contains only the model weights, e.g., saved by

torch.save(model.state_dict(), PATH)

Is that assumption correct? If so, the file does not contain information about the model architecture, only the trained weights. You should instantiate the model first and then load the weights into it. Afterwards, you can attempt to convert it to TorchScript.
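A quick way to check which case applies is to look at the type of whatever torch.load returns. The sketch below is a rough heuristic, not an official API: describe_checkpoint is a hypothetical helper, and an in-memory buffer with a small nn.Linear stands in for model/model.pt. On recent PyTorch versions, loading a fully pickled model may additionally require weights_only=False.

```python
import io

import torch
import torch.nn as nn


def describe_checkpoint(obj) -> str:
    """Hypothetical helper: roughly classify what torch.load returned."""
    if isinstance(obj, nn.Module):
        return "full model"
    if isinstance(obj, dict) and all(isinstance(v, torch.Tensor) for v in obj.values()):
        return "state_dict (weights only)"
    return "other (e.g. a checkpoint dict with extra entries)"


# In-memory stand-in for model/model.pt, saved with torch.save(model.state_dict(), ...)
net = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

loaded = torch.load(buffer)
print(describe_checkpoint(loaded))  # state_dict (weights only)
```

If it reports a state_dict, the architecture has to be rebuilt in code before load_state_dict can be called.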


Thanks @Paplham for your reply.
Yes, model.pt contains only the model weights. The error I got when I deployed the model using monai-deploy-app-sdk was:

ItemNotExistsError: A predictor of the model is not set.

If I instantiate the model and then load the weights, what should I do to convert it to TorchScript?

First, you need to instantiate the model and load the weights. Please see the "Saving and Loading Models" tutorial (PyTorch Tutorials 2.0.1+cu117 documentation) for this.

I will also provide an example below, using resnet18 from torchvision.

First, we instantiate the model:

import torch
from torchvision.models import resnet18

model = resnet18(weights=None)  # randomly initialized; `weights` replaces the deprecated `pretrained` flag

You can then train the model and afterwards save the trained weights:

torch.save(model.state_dict(), 'trained_weights.pt')

You have probably already done those steps. Now we want to load the model and export it to TorchScript. For this, you need to instantiate the model again and load the saved weights into it:

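model = resnet18(weights=None)  # same architecture as the one that was trained
model.load_state_dict(torch.load('trained_weights.pt', map_location='cpu'))  # map_location lets GPU-saved weights load on CPU
model.eval()  # switch BatchNorm/Dropout to inference behavior before exporting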

If your model uses standard PyTorch operations, exporting it should be as simple as calling:

model_scripted = torch.jit.script(model) # Export to TorchScript
torch.jit.save(model_scripted, 'exported_model.pt')

Please see the torch.jit.script page of the PyTorch 2.0 documentation for more details.
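On the deployment side, the exported file can be loaded with torch.jit.load without the Python class that defined the model. A minimal round-trip sketch, assuming a small hypothetical nn.Sequential in place of resnet18 (so torchvision is not needed) and a temporary directory in place of a real path:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for a trained network; any scriptable nn.Module works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "exported_model.pt")
    torch.jit.save(torch.jit.script(model), path)

    # Deployment side: only the file is needed, not the class definition.
    loaded = torch.jit.load(path, map_location="cpu")
    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        same = torch.allclose(loaded(x), model(x))

print(same)  # True
```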


Thank you so much, @Paplham!
I have done this:

import torch
from monai.networks.nets import UNet

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=[16, 32, 64, 128, 256],
    strides=[2, 2, 2, 2],
    num_res_units=2,
    norm="batch",
)

model.load_state_dict(torch.load('model/model.pt'))

model.eval()

# example = torch.rand(1, 1, 96, 96, 96)  # (batch, channel, depth, height, width)

model_scripted = torch.jit.script(model)
torch.jit.save(model_scripted, 'model/torchscript_model.ts')

Don't we need the model's input shape to pass into torch.jit.script?

It depends on the model you are using. Generally, you do not need to provide the input shape for scripting. Perhaps you are thinking of tracing?
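If you do want to trace, the example input must match what the network expects at runtime. A 3D convolutional network takes (batch, channels, depth, height, width), so a single-channel 96x96x96 patch is torch.rand(1, 1, 96, 96, 96); the 13 in the earlier attempt was the number of training images, not a tensor dimension. The sketch below uses a tiny hypothetical Conv3d stack as a stand-in for the MONAI UNet:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the 3D UNet: 1 input channel -> 2 output channels.
model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(8, 2, kernel_size=1),
).eval()

# 3D conv nets expect (batch, channels, depth, height, width).
example = torch.rand(1, 1, 96, 96, 96)

with torch.no_grad():
    traced = torch.jit.trace(model, example)

# The recorded convolutions do not fix the batch size, so a batch of 2 still works here.
out = traced(torch.rand(2, 1, 96, 96, 96))
print(out.shape)  # torch.Size([2, 2, 96, 96, 96])
```

For the UNet above, the analogous call would presumably be torch.jit.trace(model, torch.rand(1, 1, 96, 96, 96)) after model.eval().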


Yes. What is the actual difference between scripting and tracing? Why does tracing need the input shape while scripting does not?

There is a long discussion to be had on that subject; if you are interested, a lot can be found in the documentation.

However, the basic intuition behind the two is as follows:

Scripting inspects the model's source code and compiles it, saving the weights together with the compiled code. Methods other than forward can also be compiled if they are decorated with @torch.jit.export, and much more can be done.
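To illustrate the export decorator: in the sketch below, a hypothetical Segmenter model (a single Conv3d standing in for a real network) gets an extra predict_labels method marked with @torch.jit.export, so it is compiled alongside forward and can be called on the scripted module:

```python
import torch
import torch.nn as nn


class Segmenter(nn.Module):
    """Hypothetical stand-in for a segmentation network (1 -> 2 channels)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv3d(1, 2, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-voxel class scores, shape (N, 2, D, H, W)
        return self.conv(x)

    @torch.jit.export
    def predict_labels(self, x: torch.Tensor) -> torch.Tensor:
        # Compiled because of @torch.jit.export:
        # argmax over the channel dim gives an (N, D, H, W) label map
        return torch.argmax(self.forward(x), dim=1)


scripted = torch.jit.script(Segmenter().eval())
labels = scripted.predict_labels(torch.rand(1, 1, 8, 8, 8))
print(labels.shape)  # torch.Size([1, 8, 8, 8])
```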

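Tracing runs an example input through the model and records the operations performed on that input to produce the output. The recorded operations are then compiled and saved. This is why tracing needs an example input (and therefore a valid input shape), while scripting does not.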

What is the practical difference? Well, if your model's forward pass contains an if/else statement, what happens when we script vs. trace it?

If we trace the model, only the branch taken for the example input is recorded, so the exported model will always perform that same operation. If we script the model, the if/else condition is compiled and checked at runtime, so different inputs can run different operations.
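The if/else case can be demonstrated directly. In this sketch, a hypothetical SignGate module doubles its input when the sum is positive and negates it otherwise. The traced version, recorded with a positive example, keeps doubling even for negative inputs (tracing also emits a TracerWarning about the data-dependent branch, which is silenced here):

```python
import warnings

import torch
import torch.nn as nn


class SignGate(nn.Module):
    """Hypothetical module with a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        else:
            return -x


model = SignGate().eval()
pos = torch.ones(3)
neg = -torch.ones(3)

scripted = torch.jit.script(model)  # keeps the if/else
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning
    traced = torch.jit.trace(model, pos)  # records only the `x * 2` branch

print(scripted(neg))  # tensor([1., 1., 1.])    else-branch evaluated at runtime
print(traced(neg))    # tensor([-2., -2., -2.]) branch frozen at trace time
```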

There are more differences, but I think the if/else example showcases the main one well.


Thank you @Paplham for your help. My issue is solved.


Glad I could be of help! :)
