TypeError in torch.jit.trace(model, example)

Ubuntu 16 LTS, torch version 1.0.1.post2 (CPU build), Python 2.7.15 | Anaconda

Here's my code:
```python
import torch
import torchvision
import numpy as np
from importlib import import_module
from torch.nn import DataParallel

# net_detector.py provides get_model(); nod_net is the nn.Module to trace
model = import_module('net_detector')
config1, nod_net, loss, get_pbb = model.get_model()
checkpoint = torch.load('130.ckpt')
nod_net.load_state_dict(checkpoint['state_dict'])

example = np.load('./236350_clean.npy')
example = torch.from_numpy(example)

traced_script_module = torch.jit.trace(nod_net, example)
```

I got the following error:

```
Traceback (most recent call last):
  File "convert2pt.py", line 18, in <module>
    traced_script_module = torch.jit.trace(nod_net, example)
  File "/home/qrf/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 636, in trace
    var_lookup_fn, _force_outplace)
  File "/home/qrf/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 487, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/qrf/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes exactly 3 arguments (2 given)
```
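In Python 2 the count in `forward() takes exactly 3 arguments (2 given)` includes `self`, so the traced module's `forward` expects two tensors while `torch.jit.trace` supplied only one. `torch.jit.trace` accepts a tuple as its example inputs and unpacks it into `forward`. A minimal sketch, assuming the detector's `forward` takes an input tensor plus a second tensor; `TwoInputNet`, `coord`, and all shapes below are hypothetical stand-ins, not the actual `net_detector` code:

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for a model whose forward() needs two tensors."""

    def __init__(self):
        super(TwoInputNet, self).__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x, coord):
        # Both arguments are required, so tracing must supply both.
        return self.fc(x) + coord


net = TwoInputNet().eval()
x = torch.randn(1, 8)      # hypothetical main input
coord = torch.randn(1, 2)  # hypothetical second input

# torch.jit.trace(net, x) would fail because forward() is missing coord;
# passing a tuple makes the tracer call net(x, coord).
traced = torch.jit.trace(net, (x, coord))

# The traced module is called with the same two arguments.
out = traced(x, coord)
print(out.shape)
```

For the real model, the second example tensor would need whatever shape and dtype its `forward` actually expects.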

The pre-saved model was saved with PyTorch 0.3.1.

Do you have a version of the model that is compatible with the current version of PyTorch? We do not guarantee that tracing works on models written for PyTorch 0.3.
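One way to check whether tracing captured a model correctly under the installed PyTorch version is to compare the traced module's output with the eager model's output on an input different from the one used for tracing. A minimal sketch; `ToyNet` and its shapes are hypothetical stand-ins for the real network, and the check only covers the code path exercised by the chosen input:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNet(nn.Module):
    """Hypothetical stand-in model; replace with the real network."""

    def __init__(self):
        super(ToyNet, self).__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return F.relu(self.fc(x))


torch.manual_seed(0)
model = ToyNet().eval()

# Trace with one example input...
traced = torch.jit.trace(model, torch.randn(2, 4))

# ...then compare eager and traced outputs on a different input.
check = torch.randn(5, 4)
with torch.no_grad():
    max_diff = (model(check) - traced(check)).abs().max().item()
print("max abs difference:", max_diff)
```

A large difference would suggest the trace recorded behavior (for example, data-dependent control flow) that does not generalize to other inputs.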

Yes, thank you. Maybe I should retrain my model with the stable version to find out whether it's a version issue.
