Hi
I tried to perform jit.trace
trace_model = torch.jit.trace(net,x)
trace_model.save('out.pt')
on a net that contains
import torch
from torch import nn as nn
from torch.nn import functional as F
__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']
@torch.jit.script
def swish_jit_fwd(x):
return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
Inspired by conversation btw Jeremy Howard & Adam Pazske
https://twitter.com/jeremyphoward/status/1188251041835315200
Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
and also as Swish (https://arxiv.org/abs/1710.05941).
TODO Rename to SiLU with addition to PyTorch
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return swish_jit_bwd(x, grad_output)
def swish_me(x, inplace=False):
return SwishJitAutoFn.apply(x)
class SwishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishMe, self).__init__()
def forward(self, x):
return SwishJitAutoFn.apply(x)
While running
trace_model.save('out.pt')
it fails at ‘SwishJitAutoFn’ with following error message:
Traceback (most recent call last):
File "rk3566Test.py", line 269, in <module>
export_pytorch_model()
File "rk3566Test.py", line 56, in export_pytorch_model
trace_model.save('./sqnet_3566.pt')
File "/home/paul/rknn2/lib/python3.6/site-packages/torch/jit/__init__.py", line 1987, in save
return self._c.save(*args, **kwargs)
RuntimeError:
Could not export Python function call 'SwishJitAutoFn'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home/paul/.cache/torch/hub/rwightman_gen-efficientnet-pytorch_master/geffnet/activations/activations_me.py(63): forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/paul/pytorch/AdaBins/models/unet_adaptive_bins.py(73): forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/paul/pytorch/AdaBins/models/unet_adaptive_bins.py(93): forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/paul/rknn2/lib/python3.6/site-packages/torch/jit/__init__.py(1109): trace_module
/home/paul/rknn2/lib/python3.6/site-packages/torch/jit/__init__.py(955): trace
rk3566Test.py(54): export_pytorch_model
rk3566Test.py(269): <module>
Am I missing something for the trace_model.save() or the SwishJitAutoFn code needs to be revised to be trace_model.save friendly?
Thank you very much for your help in advance.