JIT-compiling a staticmethod of an nn.Module class

I want to export my model to ONNX. As a first step, I want to compile my Module with torch.jit.script. My module contains one helper staticmethod that is used at inference time. However, when I call the helper function on the JIT-compiled model, I get the following error:

Traceback (most recent call last):
  File "SimpleONNX.py", line 28, in <module>
    print(model.return_ata(mat, x))
  File "/Users/…/venv/lib/python3.7/site-packages/torch/jit/_script.py", line 667, in __getattr__
    return super(RecursiveScriptModule, self).__getattr__(attr)
  File "/Users/…/venv/lib/python3.7/site-packages/torch/jit/_script.py", line 384, in __getattr__
    return super(ScriptModule, self).__getattr__(attr)
  File "/Users/…/venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
    type(self).__name__, name))
AttributeError: 'RecursiveScriptModule' object has no attribute 'return_ata'

Here is my Python program:

import torch
from torch import nn

class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.mat = torch.rand(3,3)

    def forward(self, x:torch.Tensor):
        return torch.matmul(self.mat, x)

    @staticmethod
    def return_ata(mat:torch.Tensor,x:torch.Tensor):
        matx = torch.matmul(mat,x)
        xt = torch.transpose(x,0,1)
        return torch.matmul(xt, matx)


x = torch.rand((3,1))
mat = torch.rand(3,3)

SL = SimpleLinear()

model = torch.jit.script(SL)
print(model(x))
print(model.return_ata(mat, x))  # raises the AttributeError shown above

Is there any way to keep return_ata as a member function? Also, will it be preserved in the exported ONNX model when it runs under ONNX Runtime?
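One workaround that keeps return_ata a @staticmethod in the eager class is sketched below; it is not from the original post. Accessed through the class, a staticmethod is a plain function, so it can be passed to torch.jit.script on its own. The name return_ata_scripted is hypothetical. The trade-off is that the compiled function lives next to the scripted module rather than on it, so it still is not reachable as model.return_ata.

```python
import torch
from torch import nn


class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.mat = torch.rand(3, 3)

    def forward(self, x: torch.Tensor):
        return torch.matmul(self.mat, x)

    @staticmethod
    def return_ata(mat: torch.Tensor, x: torch.Tensor):
        matx = torch.matmul(mat, x)
        xt = torch.transpose(x, 0, 1)
        return torch.matmul(xt, matx)


# Scripting the module compiles forward(); the staticmethod is not attached
# to the resulting RecursiveScriptModule.
model = torch.jit.script(SimpleLinear())

# Accessed via the class, the staticmethod is an ordinary function, so it can
# be scripted separately (return_ata_scripted is a hypothetical name).
return_ata_scripted = torch.jit.script(SimpleLinear.return_ata)

x = torch.rand((3, 1))
mat = torch.rand(3, 3)
print(model(x))                      # shape (3, 1)
print(return_ata_scripted(mat, x))   # shape (1, 1)
```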


I can get it to work as a class member by decorating it with @torch.jit.export instead (which means dropping @staticmethod and adding self), i.e.:


    @torch.jit.export
    def return_ata(self, mat:torch.Tensor,x:torch.Tensor):
        matx = torch.matmul(mat,x)
        xt = torch.transpose(x,0,1)
        return torch.matmul(xt, matx)
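With the @torch.jit.export version, the method can be called on the scripted module. For the ONNX question: to my understanding, torch.onnx.export records a single graph from the module's forward(), so return_ata, exported or not, would not be included in the .onnx file by itself. A common pattern, sketched below under that assumption, is a small wrapper module (ReturnAtaWrapper is a hypothetical name) whose forward() just calls the helper, exported as its own model. The export call is shown commented out and is not run here.

```python
import torch
from torch import nn


class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.mat = torch.rand(3, 3)

    def forward(self, x: torch.Tensor):
        return torch.matmul(self.mat, x)

    # @torch.jit.export makes torch.jit.script compile this method too.
    @torch.jit.export
    def return_ata(self, mat: torch.Tensor, x: torch.Tensor):
        matx = torch.matmul(mat, x)
        xt = torch.transpose(x, 0, 1)
        return torch.matmul(xt, matx)


class ReturnAtaWrapper(nn.Module):
    """Hypothetical wrapper whose forward() is just return_ata."""

    def __init__(self, inner: SimpleLinear):
        super().__init__()
        self.inner = inner

    def forward(self, mat: torch.Tensor, x: torch.Tensor):
        return self.inner.return_ata(mat, x)


x = torch.rand((3, 1))
mat = torch.rand(3, 3)

model = torch.jit.script(SimpleLinear())
print(model(x))                  # forward, shape (3, 1)
print(model.return_ata(mat, x))  # exported method, shape (1, 1)

# The wrapper is itself scriptable, and its forward() is what an ONNX export
# would capture, e.g. (assumed usage, not run here):
# torch.onnx.export(ReturnAtaWrapper(SimpleLinear()), (mat, x), "return_ata.onnx")
wrapper = torch.jit.script(ReturnAtaWrapper(SimpleLinear()))
print(wrapper(mat, x))           # shape (1, 1)
```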