I have a trained model, saved as model.pt, for brain segmentation from head CT scans.
How can I convert it into a TorchScript file so that I can use the model for deployment?
Network definition:
3D UNet
in channels: 1 (image)
out channels: 2 (background and brain label)
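Both tracing and scripting need an `nn.Module` instance, so the network has to be rebuilt in Python with exactly the architecture that produced model.pt. The class below is a hypothetical, deliberately tiny 3D U-Net-style stand-in (one down/up level with a skip connection). It only matches the stated 1 input and 2 output channels. The real depth, layer widths and normalisation must come from the original training code. The input/output definitions appear to follow the layout of a MONAI bundle's metadata.json, so if the network is MONAI's `UNet`, its constructor arguments in the bundle config are what define it.

```python
import torch
import torch.nn as nn


class TinyUNet3D(nn.Module):
    """Hypothetical stand-in for the real 3D UNet: 1 input channel -> 2 output channels."""

    def __init__(self, in_channels: int = 1, out_channels: int = 2, width: int = 8):
        super().__init__()
        # Encoder: two 3x3x3 convolutions at full resolution
        self.enc = nn.Sequential(
            nn.Conv3d(in_channels, width, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv3d(width, width, kernel_size=3, padding=1), nn.ReLU(inplace=True),
        )
        # Bottleneck at half resolution
        self.down = nn.MaxPool3d(kernel_size=2)
        self.bottleneck = nn.Sequential(
            nn.Conv3d(width, 2 * width, kernel_size=3, padding=1), nn.ReLU(inplace=True),
        )
        # Decoder: upsample back to full resolution and fuse with the skip connection
        self.up = nn.ConvTranspose3d(2 * width, width, kernel_size=2, stride=2)
        self.dec = nn.Sequential(
            nn.Conv3d(2 * width, width, kernel_size=3, padding=1), nn.ReLU(inplace=True),
        )
        # 1x1x1 convolution to per-class scores (channel 0 = background, 1 = brain)
        self.head = nn.Conv3d(width, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = self.enc(x)
        y = self.up(self.bottleneck(self.down(skip)))
        y = self.dec(torch.cat([y, skip], dim=1))
        return self.head(y)


model = TinyUNet3D().eval()
x = torch.rand(1, 1, 96, 96, 96)  # (batch, channel, depth, height, width)
with torch.no_grad():
    out = model(x)
print(tuple(out.shape))  # (1, 2, 96, 96, 96)
```

Spatial sizes divisible by 2 are assumed here because of the single pooling step; a real U-Net with more levels needs the patch size to be divisible by correspondingly larger powers of 2.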
Input definition:
    "image": {
        "type": "image",
        "format": "hounsfield",
        "modality": "CT",
        "num_channels": 1,
        "spatial_shape": [96, 96, 96],
        "dtype": "float32",
        "value_range": [0, 1],
        "is_patch_data": true,
        "channel_def": {
            "0": "image"
        }
    }
},
Train/val split: 13 images for training and 3 for validation
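The input definition translates into a float32 tensor of shape (N, C, D, H, W) = (1, 1, 96, 96, 96): one patch, one channel, 96 voxels per spatial axis. The 13 training images do not appear anywhere in that shape. Because the format is listed as Hounsfield but the value range is [0, 1], the CT intensities were presumably windowed and rescaled before entering the network. The sketch below assumes a simple linear window; the helper name `hu_to_unit_range` and the `a_min`/`a_max` values are hypothetical placeholders to be replaced with whatever the original preprocessing used.

```python
import torch


def hu_to_unit_range(hu: torch.Tensor, a_min: float, a_max: float) -> torch.Tensor:
    """Linearly map the Hounsfield window [a_min, a_max] to [0, 1], clipping values outside it.

    a_min/a_max are placeholders: use the window from the original training preprocessing.
    """
    scaled = (hu - a_min) / (a_max - a_min)
    return scaled.clamp(0.0, 1.0)


# Hypothetical CT patch in Hounsfield units (air is about -1000, dense bone about +1000)
hu_patch = torch.empty(96, 96, 96).uniform_(-1000.0, 1000.0)

# Assumed window; replace with the one used during training
patch = hu_to_unit_range(hu_patch, a_min=-1000.0, a_max=1000.0)

# Add batch and channel dimensions: (N, C, D, H, W) = (1, 1, 96, 96, 96)
example = patch.to(torch.float32)[None, None]
print(tuple(example.shape), example.dtype)
```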
Output definition:
    "pred": {
        "type": "image",
        "format": "segmentation",
        "num_channels": 2,
        "spatial_shape": [96, 96, 96],
        "dtype": "float32",
        "value_range": [0, 1],
        "is_patch_data": true,
        "channel_def": {
            "0": "background",
            "1": "brain"
        }
    }
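The output definition describes 2 values per voxel in [0, 1], with channel 1 meaning brain. Assuming the exported network returns raw per-class scores (logits), which is common but should be checked against the training code, a softmax over the channel dimension gives values in that range, and an argmax gives a binary brain mask. The function name and the synthetic scores below are illustrative only.

```python
import torch


def logits_to_brain_mask(logits: torch.Tensor):
    """Convert (N, 2, D, H, W) scores into probabilities and an (N, D, H, W) brain mask."""
    probs = torch.softmax(logits, dim=1)        # the 2 values of each voxel now sum to 1
    mask = probs.argmax(dim=1).to(torch.uint8)  # 0 = background, 1 = brain (per channel_def)
    return probs, mask


# Synthetic scores: background wins everywhere except a 20x20x20 cube where brain wins
logits = torch.zeros(1, 2, 96, 96, 96)
logits[:, 0] = 1.0
logits[:, 1, 40:60, 40:60, 40:60] = 5.0

probs, mask = logits_to_brain_mask(logits)
print(tuple(probs.shape), tuple(mask.shape), int(mask.long().sum()))
```

Softmax is monotonic, so the argmax of the logits gives the same mask; the softmax step only matters if per-voxel probabilities are needed downstream.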
Now, how can I use tracing or scripting to convert it into TorchScript?
Is this information enough?
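The metadata above is enough to build the example input, but not to reconstruct the network: the exact architecture has to come from the training code. With that in hand, a minimal tracing workflow (a sketch, assuming model.pt stores a state_dict and that `forward` takes one tensor and returns one tensor) is: rebuild the network, load the weights, switch to eval mode, trace with one dummy patch of shape (1, 1, 96, 96, 96), save, and check that the reloaded TorchScript module reproduces the eager output. `build_network` is a hypothetical placeholder with a tiny stand-in architecture, and a temporary directory replaces the real model/ folder so the snippet runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_network() -> nn.Module:
    # Hypothetical stand-in: replace with the exact 3D UNet class and arguments used in training
    return nn.Sequential(
        nn.Conv3d(1, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv3d(8, 2, kernel_size=1),
    )


workdir = tempfile.mkdtemp()
weights_path = os.path.join(workdir, "model.pt")
ts_path = os.path.join(workdir, "model_traced.ts")

# Simulate the existing checkpoint: a state_dict, not a full module
torch.save(build_network().state_dict(), weights_path)

# 1) Rebuild the architecture and load the weights into it
model = build_network()
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()  # put dropout/batch-norm layers into inference mode before export

# 2) Trace with a single dummy patch shaped (N, C, D, H, W)
example = torch.rand(1, 1, 96, 96, 96)
with torch.no_grad():
    traced = torch.jit.trace(model, example)
traced.save(ts_path)  # equivalent to torch.jit.save(traced, ts_path)

# 3) Reload without the Python class and check that the outputs agree
loaded = torch.jit.load(ts_path, map_location="cpu")
with torch.no_grad():
    max_diff = (loaded(example) - model(example)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The saved .ts file can then be loaded with `torch.jit.load` (or `torch::jit::load` in C++) without the Python class definition being importable.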
I tried the following:
import torch
model = torch.load('model/model.pt')
example = torch.rand(13, 96, 96, 96)
traced_script_module = torch.jit.script(model, (example))
torch.save(traced_script_module, "model/traced_resnet_model.ts")
I only used the model input size. I also tried torch.jit.trace, but both failed.
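The attempt above fails for three separate reasons. First, if model.pt is a weights file (a state_dict, as "weight model" suggests), `torch.load` returns a dictionary of tensors rather than an `nn.Module`, so there is nothing to trace or script. Second, `torch.jit.script` does not take an example input as its second positional argument; it compiles the module's Python code and needs no example, whereas `torch.jit.trace(model, example)` does take one. Third, 3D convolutions expect a 5-D (N, C, D, H, W) tensor such as (1, 1, 96, 96, 96); a (13, 96, 96, 96) tensor is either rejected or read as an unbatched patch with 13 channels, depending on the PyTorch version, and 13 is the number of training images, not a tensor dimension. The sketch below checks the first and third points using a hypothetical one-layer stand-in network and a temporary file instead of the real model/model.pt; `describe_checkpoint` is an illustrative helper, not a PyTorch API.

```python
import os
import tempfile

import torch
import torch.nn as nn


def describe_checkpoint(path: str) -> str:
    """Report whether a .pt file holds a whole module or only a dict of weights."""
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, nn.Module):
        return "nn.Module: can be passed to torch.jit.trace/script directly"
    if isinstance(obj, dict):  # a state_dict is an OrderedDict, i.e. a dict subclass
        return f"dict with {len(obj)} entries: rebuild the network, then load_state_dict"
    return f"unexpected object of type {type(obj).__name__}"


# Hypothetical stand-in for the real network (1 input channel, 2 output channels)
net = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=3, padding=1)

path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(net.state_dict(), path)  # what a weights-only model.pt typically contains
report = describe_checkpoint(path)
print(report)

# Shape check: 13 is a dataset size, not a tensor dimension
with torch.no_grad():
    try:
        net(torch.rand(13, 96, 96, 96))
        bad_shape_ok = True
    except RuntimeError:
        bad_shape_ok = False
    good_shape = tuple(net(torch.rand(1, 1, 96, 96, 96)).shape)
print(bad_shape_ok, good_shape)
```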
Any help would be much appreciated.
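If the network's `forward` contains Python control flow that depends on the input (for example an optional batch dimension), `torch.jit.trace` typically records only the branch taken for the example input, while `torch.jit.script` compiles the code itself and needs no example. A minimal scripting sketch, assuming the model's `forward` uses TorchScript-compatible Python: `FlexibleSegNet` is a hypothetical stand-in, and `torch.jit.freeze` is an optional inference-only step that inlines the weights as constants.

```python
import io

import torch
import torch.nn as nn


class FlexibleSegNet(nn.Module):
    """Hypothetical stand-in: 1 input channel -> 2 output channels, with a shape-dependent branch."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(8, 2, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept an unbatched (C, D, H, W) patch as well as a batched (N, C, D, H, W) one
        if x.dim() == 4:
            x = x.unsqueeze(0)
        return self.body(x)


model = FlexibleSegNet().eval()      # eval mode is required by torch.jit.freeze
scripted = torch.jit.script(model)   # compiles forward, branch included; no example input
frozen = torch.jit.freeze(scripted)  # optional: inline weights for inference-only use

# Round-trip through an in-memory buffer; a file path (e.g. a hypothetical
# "model/model_scripted.ts") works the same way with torch.jit.save/load
buffer = io.BytesIO()
torch.jit.save(frozen, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

with torch.no_grad():
    batched = loaded(torch.rand(1, 1, 96, 96, 96))
    unbatched = loaded(torch.rand(1, 96, 96, 96))
print(tuple(batched.shape), tuple(unbatched.shape))
```

For a plain convolutional U-Net without such branches, tracing and scripting should produce equivalent modules; scripting mainly matters when `forward` has input-dependent logic or uses constructs that tracing cannot capture.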