I am trying to get a TorchScript .pt file from a trained X3D network (a 3D CNN). One module
is causing the problem: an implementation of the Swish activation, given by:
    import torch
    import torch.nn as nn


    class Swish(nn.Module):
        """
        Wrapper for the Swish activation function.
        """
        def forward(self, x):
            return SwishFunction.apply(x)


    class SwishFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            result = x * torch.sigmoid(x)
            # Save only the input; sigmoid(x) is recomputed in backward
            ctx.save_for_backward(x)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            # saved_tensors replaces the deprecated saved_variables
            x = ctx.saved_tensors[0]
            sigmoid_x = torch.sigmoid(x)
            # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
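As a sanity check that the custom backward formula is the true derivative of x * sigmoid(x), namely sigmoid(x) * (1 + x * (1 - sigmoid(x))), here is a small stdlib-only sketch that compares it against a central finite difference. The helper names (`swish`, `swish_grad`, `finite_difference`) are hypothetical and simply mirror the PyTorch code on plain floats.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def swish(x: float) -> float:
    # Same computation as SwishFunction.forward: x * sigmoid(x)
    return x * sigmoid(x)


def swish_grad(x: float) -> float:
    # Same formula as SwishFunction.backward, with grad_output = 1
    s = sigmoid(x)
    return s * (1 + x * (1 - s))


def finite_difference(f, x: float, h: float = 1e-6) -> float:
    # Central-difference approximation of f'(x)
    return (f(x + h) - f(x - h)) / (2 * h)


for x in [-4.0, -1.0, 0.0, 0.5, 3.0]:
    analytic = swish_grad(x)
    numeric = finite_difference(swish, x)
    print(f"x={x:5.1f}  analytic={analytic:.6f}  numeric={numeric:.6f}")
```

The two columns should agree to several decimal places, which confirms the backward pass itself is correct; the scripting failure comes from the `autograd.Function` call, not from the math.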
When I try to script the X3D network, I get an error specific to this module, which has a custom backward-pass implementation. The error reads:
    Could not export Python function call 'SwishFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
Any thoughts on how to fix this?
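A minimal sketch of one possible workaround, assuming the custom backward pass is not needed in the exported model: as the error indicates, the scripted model cannot contain a call to the Python-level `SwishFunction`, so Swish can instead be written with ordinary tensor ops, which TorchScript compiles and for which autograd derives the same gradient. `ScriptableSwish` is a hypothetical name; in the X3D model it would take the place of `Swish`.

```python
import torch
import torch.nn as nn


class ScriptableSwish(nn.Module):
    """Swish built from plain tensor ops so torch.jit.script can compile it.

    No custom torch.autograd.Function is used; autograd derives the
    backward pass from x * sigmoid(x) automatically.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


scripted = torch.jit.script(ScriptableSwish())

x = torch.randn(2, 3, requires_grad=True)
y = scripted(x)
y.sum().backward()

# Compare autograd's gradient with the formula from SwishFunction.backward
s = torch.sigmoid(x.detach())
expected_grad = s * (1 + x.detach() * (1 - s))
print(torch.allclose(x.grad, expected_grad, atol=1e-6))
```

The trade-off is memory: the custom Function saves only `x` for backward, whereas plain autograd may keep extra intermediates such as `sigmoid(x)`. Recent PyTorch releases also ship `nn.SiLU`, which computes the same x * sigmoid(x) and could be used instead.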