Environment: Ubuntu 16.04 LTS, torch 1.0.1.post2 (CPU-only build), Python 2.7.15 (Anaconda).
Here's my code:
import torch
import torchvision
import numpy as np
from importlib import import_module
from torch.nn import DataParallel
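model = import_module('net_detector')
# get_model() returns the config, the network (an nn.Module), the loss and get_pbb
config1, nod_net, loss, get_pbb = model.get_model()
checkpoint = torch.load('130.ckpt')
# the weights go into the network, not into the imported Python module
nod_net.load_state_dict(checkpoint['state_dict'])
example = np.load('./236350_clean.npy')
example = torch.from_numpy(example)
traced_script_module = torch.jit.trace(nod_net, example)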
I got the following error:
Traceback (most recent call last):
  File "convert2pt.py", line 18, in <module>
    traced_script_module = torch.jit.trace(nod_net, example)
  File "/home/qrf/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 636, in trace
    var_lookup_fn, _force_outplace)
  File "/home/qrf/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 487, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/qrf/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes exactly 3 arguments (2 given)
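In Python 2 the argument count in that message includes `self`, so the traced `forward()` appears to expect two tensors while `torch.jit.trace` passed only `example`. `torch.jit.trace` takes its example inputs as a tuple and unpacks them into `forward()`. Below is a minimal sketch, assuming the detector's `forward()` takes an image tensor plus one extra tensor; `TwoInputNet`, the name `coord` and all shapes are hypothetical stand-ins, not the real `net_detector` code.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for a detector whose forward() needs two tensors."""

    def __init__(self):
        super(TwoInputNet, self).__init__()
        self.conv = nn.Conv3d(1, 4, kernel_size=3, padding=1)

    def forward(self, x, coord):
        # x: image patch (N, 1, D, H, W); coord: hypothetical extra input (N, 3)
        feat = self.conv(x)                                     # (N, 4, D, H, W)
        feat = feat.view(feat.size(0), feat.size(1), -1).mean(2)  # (N, 4)
        return feat + coord.sum(1, keepdim=True)                # broadcast to (N, 4)


net = TwoInputNet().eval()
x = torch.randn(1, 1, 8, 8, 8)
coord = torch.randn(1, 3)

# Pass every positional argument of forward() as one tuple of example inputs;
# tracing with only `x` fails because forward() is missing `coord`.
traced = torch.jit.trace(net, (x, coord))
out = traced(x, coord)
print(out.shape)
```

For the real model, the second tensor would have to match whatever `nod_net.forward()` actually expects (its shape and dtype are not shown in the post), and the NumPy array may also need a batch dimension and a `float32` cast before tracing.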