Simple Conv2d Function cannot be scripted and reports Runtime Error.
Here is my simple Conv2d module, I want to script it using torch.jit.script
.
import torch
class Conv2dCell(torch.nn.Module):
def __init__(self):
super(Conv2dCell, self).__init__()
def forward(self, x):
conv = torch.nn.Conv2d(1, 3, 3, stride=1)
output = conv(x)
return output
m = Conv2dCell()
scripted_m = torch.jit.script(m)
Running this piece of code will give the following error message:
Traceback (most recent call last):
File “conv2d.py”, line 13, in
scripted_m = torch.jit.script(m)
File “/mnt/ssd/maxhy/py36/lib/python3.6/site-packages/torch/jit/init.py”, line 1261, in script
return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
File “/mnt/ssd/maxhy/py36/lib/python3.6/site-packages/torch/jit/_recursive.py”, line 305, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File “/mnt/ssd/maxhy/py36/lib/python3.6/site-packages/torch/jit/_recursive.py”, line 361, in create_script_module_impl
create_methods_from_stubs(concrete_type, stubs)
File “/mnt/ssd/maxhy/py36/lib/python3.6/site-packages/torch/jit/_recursive.py”, line 279, in create_methods_from_stubs
concrete_type._create_methods(defs, rcbs, defaults)
File “/mnt/ssd/maxhy/py36/lib/python3.6/site-packages/torch/jit/init.py”, line 1108, in _compile_and_register_class
_jit_script_class_compile(qualified_name, ast, rcb)
RuntimeError:
Arguments for call are not valid.
The following variants are available:_pair(float[2] x) → (float):
Expected a value of type ‘List[float]’ for argument ‘x’ but instead found type ‘Tensor’._pair(int[2] x) → (int):
Expected a value of type ‘List[int]’ for argument ‘x’ but instead found type ‘Tensor’.The original call is:
File “/mnt/ssd/maxhy/py36/lib/python3.6/site-packages/torch/nn/modules/conv.py”, line 336
padding=0, dilation=1, groups=1,
bias=True, padding_mode=‘zeros’):
kernel_size = _pair(kernel_size)
~~~~~ <— HERE
stride = _pair(stride)
padding = _pair(padding)
‘Conv2d.init’ is being compiled since it was called from ‘Conv2d’
File “conv2d.py”, line 8
def forward(self, x):
conv = torch.nn.Conv2d(1, 3, 3, stride=1)
~~~~~~~~~~~~~~~ <— HERE
output = conv(x)
return output
‘Conv2d’ is being compiled since it was called from ‘Conv2dCell.forward’
File “conv2d.py”, line 8
def forward(self, x):
conv = torch.nn.Conv2d(1, 3, 3, stride=1)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <— HERE
output = conv(x)
return output
I am using PyTorch 1.5.1 and python 3.6.13, could someone help me identify the problem?