a.py
import torch
import torch.fx
from math import sqrt
def normalize(x):
    """Scale ``x`` by the inverse square root of its length.

    Args:
        x: Any object that supports ``len()`` and division by a float
            (the accompanying example passes a ``torch.Tensor``).

    Returns:
        The result of ``x / sqrt(len(x))``.
    """
    return x / sqrt(len(x))
# Register the global names ``len`` and ``sqrt`` with torch.fx as wrapped
# ("leaf") functions, so symbolic tracing is expected to record calls to them
# rather than trace through them. These calls must stay at module top level.
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
setup.py
from distutils.core import setup
from Cython.Build import cythonize
# Build a.py with Cython as the extension module(s) of the distribution 'a'.
setup(name='a', ext_modules=cythonize(['a.py']))
python setup.py build_ext --inplace
b.py
from a import *
# Call normalize directly on a random 3x4 tensor.
normalize(torch.randn(3, 4))
# Symbolically trace the same function with torch.fx.
# NOTE(review): presumably this is the call that raises the reported
# AssertionError — confirm against the full traceback.
traced = torch.fx.symbolic_trace(normalize)
python b.py
error from _symbolic_trace.py
AssertionError
How can I deal with this?