I trained a model in Python; the network is defined as follows:
class ResidualFeatureNet(torch.nn.Module):
    """Feature network: three ConvLayers, four ResidualBlocks, then two ConvLayers.

    The submodule attribute names (conv1..conv5, resid1..resid4) determine the
    keys of ``state_dict()``; renaming them would break loading existing
    checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Stem. NOTE(review): assumes ConvLayer(in_channels, out_channels, ...).
        self.conv1 = ConvLayer(1, 32, kernel_size=5, stride=2)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=1)
        # Residual trunk, each block constructed with 128.
        self.resid1 = ResidualBlock(128)
        self.resid2 = ResidualBlock(128)
        self.resid3 = ResidualBlock(128)
        self.resid4 = ResidualBlock(128)
        # Head reducing to a single output channel (1x1 kernel on the last layer).
        self.conv4 = ConvLayer(128, 64, kernel_size=3, stride=1)
        self.conv5 = ConvLayer(64, 1, kernel_size=1, stride=1)

    def forward(self, X):
        """Run the stem, residual trunk and head on ``X``; returns a ReLU'd tensor.

        Every ConvLayer output goes through ``F.relu``; residual block outputs
        are passed on unchanged.
        """
        h = F.relu(self.conv1(X))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = self.resid1(h)
        h = self.resid2(h)
        h = self.resid3(h)
        h = self.resid4(h)
        h = F.relu(self.conv4(h))
        return F.relu(self.conv5(h))
I then converted the model to TorchScript using the following script:
import os
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from netdef_128 import ResidualFeatureNet as RFN
np.random.seed(0)
torch.manual_seed(0)
print(torch.__version__)

# Substrings identifying running-statistics buffers to strip from the checkpoint.
_RUNNING_STAT_NAMES = ('running_mean', 'running_var')


def _parse_image_size(text):
    """Parse ``"WIDTH,HEIGHT"`` into a ``(width, height)`` tuple of ints.

    Used as an argparse ``type=`` callable.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not exactly two
            comma-separated integers.
    """
    try:
        width, height = (int(part) for part in text.split(','))
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f"expected WIDTH,HEIGHT (two integers), got {text!r}") from err
    return width, height


def _drop_running_stats(state_dict):
    """Delete every key containing ``running_mean`` or ``running_var`` in place.

    NOTE(review): presumably the checkpoint holds running-stat buffers that the
    current module definition does not register (e.g. norm layers saved by an
    older PyTorch) - confirm against ConvLayer/ResidualBlock.
    Deleting in place, instead of building a new dict, keeps any ``_metadata``
    attribute attached to the loaded state dict, which ``load_state_dict`` uses.
    """
    for key in list(state_dict):  # snapshot keys: we mutate while iterating
        if any(name in key for name in _RUNNING_STAT_NAMES):
            del state_dict[key]


def main():
    """Load a ResidualFeatureNet checkpoint, trace it, and save it as TorchScript.

    Command-line arguments:
        --input    checkpoint (state_dict) file to load
        --output   path the traced TorchScript module is saved to
        --imgsize  example input size, given as WIDTH,HEIGHT
        --chan     number of channels of the example input
        --device   device used for loading and tracing (default: cuda:0)
    """
    parser = argparse.ArgumentParser(description='Pytorch model scripted')
    parser.add_argument('--input', required=True, help='input model file')
    parser.add_argument('--output', required=True, help='output model to save')
    parser.add_argument('--imgsize', required=True, type=_parse_image_size,
                        help='image size to predict, as WIDTH,HEIGHT')
    parser.add_argument('--chan', type=int, required=True,
                        help='the channel of image to predict')
    parser.add_argument('--device', default='cuda:0',
                        help='device to trace on, e.g. cuda:0 or cpu (default: cuda:0)')
    args = parser.parse_args()

    device = torch.device(args.device)
    model = RFN().to(device)
    state_dict = torch.load(args.input, map_location=str(device))
    _drop_running_stats(state_dict)
    model.load_state_dict(state_dict)
    model.eval()

    width, height = args.imgsize
    # PyTorch 2-D convolutions take (N, C, H, W) input: height precedes width.
    example = torch.rand(1, args.chan, height, width).to(device)
    traced_script_module = torch.jit.trace(model, example)

    # Sanity check: the traced module should reproduce the eager model's output.
    with torch.no_grad():
        max_diff = (traced_script_module(example) - model(example)).abs().max().item()
    print(f'max |traced - eager| difference: {max_diff:.3e}')

    # The saved module's forward() takes exactly one Tensor (ResidualFeatureNet
    # .forward(X)). A libtorch caller must push a single at::Tensor into the
    # std::vector<torch::jit::IValue> given to module.forward(); passing a list
    # of tensors as that argument fails with "Expected a value of type 'Tensor'
    # ... but instead found type 'List[Tensor]'".
    traced_script_module.save(args.output)


if __name__ == "__main__":
    main()
However, after loading the TorchScript model in C++ (libtorch 1.8), calling its forward() fails with the log below. Can anyone help me solve this issue? Thanks.
what(): forward() Expected a value of type 'Tensor' for argument 'input' but instead found type 'List[Tensor]'.
Position: 1
Declaration: forward(__torch__.netdef_128.ResidualFeatureNet self, Tensor input) -> (Tensor)
Exception raised from checkArg at /pytorch/aten/src/ATen/core/function_schema_inl.h:162 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f383c7252f2 in /usr/local/libtorch1.8/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5b (0x7f383c72267b in /usr/local/libtorch1.8/lib/libc10.so)
frame #2: <unknown function> + 0xd9668f (0x7f382b32568f in /usr/local/libtorch1.8/lib/libtorch_cpu.so)
frame #3: torch::jit::GraphFunction::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::string, c10::IValue, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, c10::IValue> > > const&) + 0x2d (0x7f382d7cb88d in /usr/local/libtorch1.8/lib/libtorch_cpu.so)
frame #4: torch::jit::Method::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::string, c10::IValue, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, c10::IValue> > > const&) + 0x138 (0x7f382d7d91e8 in /usr/local/libtorch1.8/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x27fa4 (0x559b66363fa4 in /home/kevin/palmGPU/Release/Matching)
frame #6: <unknown function> + 0x1c17b (0x559b6635817b in /home/kevin/palmGPU/Release/Matching)
frame #7: <unknown function> + 0x20d5f (0x559b6635cd5f in /home/kevin/palmGPU/Release/Matching)
frame #8: <unknown function> + 0xa189d (0x7f382941989d in /home/kevin/anaconda3/lib/libQt5Core.so.5)
frame #9: <unknown function> + 0x9609 (0x7f3828e8d609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #10: clone + 0x43 (0x7f3828fcf293 in /lib/x86_64-linux-gnu/libc.so.6)