Rwightman / gen-efficientnet-pytorch : Could not export Python function call

Hi

I tried to trace a model with torch.jit.trace:

trace_model = torch.jit.trace(net,x)
trace_model.save('out.pt')

on a net that contains the following activation code (geffnet/activations/activations_me.py):

import torch
from torch import nn as nn
from torch.nn import functional as F


__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
           'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']


@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation btw Jeremy Howard & Adam Pazske
    https://twitter.com/jeremyphoward/status/1188251041835315200

    Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
    and also as Swish (https://arxiv.org/abs/1710.05941).

    TODO Rename to SiLU with addition to PyTorch
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


def swish_me(x, inplace=False):
    return SwishJitAutoFn.apply(x)


class SwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishMe, self).__init__()

    def forward(self, x):
        return SwishJitAutoFn.apply(x)

Saving the traced model with
trace_model.save('out.pt')
fails at 'SwishJitAutoFn' with the following error message:


Traceback (most recent call last):
  File "rk3566Test.py", line 269, in <module>
    export_pytorch_model()
  File "rk3566Test.py", line 56, in export_pytorch_model
    trace_model.save('./sqnet_3566.pt')
  File "/home/paul/rknn2/lib/python3.6/site-packages/torch/jit/__init__.py", line 1987, in save
    return self._c.save(*args, **kwargs)
RuntimeError: 
Could not export Python function call 'SwishJitAutoFn'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home/paul/.cache/torch/hub/rwightman_gen-efficientnet-pytorch_master/geffnet/activations/activations_me.py(63): forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/paul/pytorch/AdaBins/models/unet_adaptive_bins.py(73): forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/paul/pytorch/AdaBins/models/unet_adaptive_bins.py(93): forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/paul/rknn2/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/paul/rknn2/lib/python3.6/site-packages/torch/jit/__init__.py(1109): trace_module
/home/paul/rknn2/lib/python3.6/site-packages/torch/jit/__init__.py(955): trace
rk3566Test.py(54): export_pytorch_model
rk3566Test.py(269): <module>

Am I missing something for trace_model.save(), or does the SwishJitAutoFn code need to be revised to work with trace_model.save()?

Thank you very much for your help in advance.

Found the solution: in geffnet's config.py, set the following flag to True:

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = True
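
For anyone hitting the same error, here is a minimal sketch of why the flag likely helps. My assumption (not confirmed in this thread) is that with _SCRIPTABLE = True the library picks an activation built from native ops instead of the custom torch.autograd.Function, and a traced graph that contains only native ops can be saved. The Swish module and the small Sequential net below are hypothetical stand-ins, not geffnet code:

```python
import io

import torch
from torch import nn


class Swish(nn.Module):
    """Plain Swish built from native ops (hypothetical stand-in for SwishMe)."""

    def forward(self, x):
        # x * sigmoid(x) is recorded by the tracer as ordinary aten operators,
        # so no Python function call ends up in the exported graph.
        return x * torch.sigmoid(x)


# Tiny stand-in network; the real model in the question is much larger.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), Swish()).eval()
x = torch.randn(1, 3, 16, 16)

# Trace and save to an in-memory buffer instead of 'out.pt'.
traced = torch.jit.trace(net, x)
buffer = io.BytesIO()
torch.jit.save(traced, buffer)

# Reload and compare against the original module.
buffer.seek(0)
reloaded = torch.jit.load(buffer)
with torch.no_grad():
    max_diff = (reloaded(x) - net(x)).abs().max().item()

print("max abs difference after reload:", max_diff)
```

If the save step succeeds and the reloaded output matches the original, the traced graph no longer depends on a Python function call.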