Hi!
When I define the following functions ‘n_deriv_’ (calling ‘jac’ recursively) and ‘n_deriv__’ (calling ‘grad’ recursively in a for loop), they work without issues:
import torch
from torch.autograd.functional import jacobian as jac
from torch.autograd.functional import hessian as hes
from torch.autograd import grad
def t(*x):
return torch.tensor([*x], dtype=torch.float32, requires_grad=True)
def f(x, p=5):
return torch.pow(x, p).sum()
print(f(t(2)))
def n_deriv_(func, x):
g = {'0': func}
g['1'] = lambda y: jac(g['0'], y, create_graph=True)
g['2'] = lambda y: jac(g['1'], y, create_graph=True)
return g['2'](x)
def n_deriv__(func, x, n=2):
g = {'0': func(x)}
for i in range(n):
g[str(i+1)] = grad(g[str(i)], x, create_graph=True)
return g[str(n)]
print(n_deriv_(f, t(1)))
print(n_deriv__(f, t(1)))
but defining and calling the function ‘n_deriv’ (calling ‘jac’ recursively in a for loop) like this
def n_deriv(func, x):
g = {'0': func}
for i in [0,1]:
g[str(i+1)] = lambda y: jac(g[str(i)], y, create_graph=True)
return g['2'](x)
print(n_deriv(f, t(1)))
leads to the following error
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
<ipython-input-92-c87c3985ce83> in <module>
----> 1 n_deriv(f, t(1))
<ipython-input-90-7ba1805d2642> in n_deriv(func, x)
5 g[str(i+1)] = lambda y: jac(g[str(i)], y, create_graph=True)
6
----> 7 return g['2'](x)
8
9
<ipython-input-90-7ba1805d2642> in <lambda>(y)
3
4 for i in [0,1]:
----> 5 g[str(i+1)] = lambda y: jac(g[str(i)], y, create_graph=True)
6
7 return g['2'](x)
~/miniforge3/envs/dev/lib/python3.9/site-packages/torch/autograd/functional.py in jacobian(func, inputs, create_graph, strict, vectorize)
472 inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
473
--> 474 outputs = func(*inputs)
475 is_outputs_tuple, outputs = _as_tuple(outputs,
476 "outputs of the user-provided function",
... last 2 frames repeated, from the frame below ...
<ipython-input-90-7ba1805d2642> in <lambda>(y)
3
4 for i in [0,1]:
----> 5 g[str(i+1)] = lambda y: jac(g[str(i)], y, create_graph=True)
6
7 return g['2'](x)
RecursionError: maximum recursion depth exceeded while calling a Python object```
My installed packages are (conda list)
absl-py 0.12.0 pyhd8ed1ab_0 conda-forge
aiohttp 3.7.4 py39h46acfd9_0 conda-forge
appnope 0.1.2 py39h2804cbe_1 conda-forge
argon2-cffi 20.1.0 py39h46acfd9_2 conda-forge
astroid 2.5.1 pypi_0 pypi
async-timeout 3.0.1 py_1000 conda-forge
async_generator 1.10 py_0 conda-forge
attrs 21.2.0 pyhd8ed1ab_0 conda-forge
autopep8 1.5.5 pypi_0 pypi
backcall 0.2.0 pyh9f0ad1d_0 conda-forge
backports 1.0 py_2 conda-forge
backports.functools_lru_cache 1.6.4 pyhd8ed1ab_0 conda-forge
bleach 3.3.0 pyh44b312d_0 conda-forge
blinker 1.4 py_1 conda-forge
brotlipy 0.7.0 py39h46acfd9_1001 conda-forge
c-ares 1.17.1 h27ca646_1 conda-forge
ca-certificates 2021.5.30 h4653dfc_0 conda-forge
cachetools 4.2.2 pyhd8ed1ab_0 conda-forge
certifi 2021.5.30 py39h2804cbe_0 conda-forge
cffi 1.14.5 py39h702c04f_0 conda-forge
chardet 4.0.0 py39h2804cbe_1 conda-forge
click 7.1.2 pyh9f0ad1d_0 conda-forge
coverage 5.5 pypi_0 pypi
cryptography 3.4.6 py39h73257c9_0 conda-forge
cycler 0.10.0 pypi_0 pypi
cython 0.29.23 pypi_0 pypi
dataclasses 0.8 pyhc8e2a94_1 conda-forge
decorator 5.0.9 pyhd8ed1ab_0 conda-forge
defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge
entrypoints 0.3 pyhd8ed1ab_1003 conda-forge
filelock 3.0.12 pyh9f0ad1d_0 conda-forge
flake8 3.8.4 pypi_0 pypi
freetype 2.10.4 h17b34a0_1 conda-forge
future 0.18.2 py39h2804cbe_3 conda-forge
google-auth 1.30.0 pyh44b312d_0 conda-forge
google-auth-oauthlib 0.4.1 py_2 conda-forge
grpcio 1.38.0 py39h9e1b6db_0 conda-forge
idna 2.10 pyh9f0ad1d_0 conda-forge
importlib-metadata 4.0.1 py39h2804cbe_0 conda-forge
ipykernel 5.5.5 py39h32adebf_0 conda-forge
ipython 7.23.1 py39h32adebf_0 conda-forge
ipython_genutils 0.2.0 py_1 conda-forge
ipywidgets 7.6.3 pyhd3deb0d_0 conda-forge
isort 5.7.0 pypi_0 pypi
jedi 0.17.2 pypi_0 pypi
jinja2 3.0.1 pyhd8ed1ab_0 conda-forge
joblib 1.0.1 pyhd8ed1ab_0 conda-forge
jpeg 9d h27ca646_0 conda-forge
jsonschema 3.2.0 pyhd8ed1ab_3 conda-forge
jupyter_client 6.1.12 pyhd8ed1ab_0 conda-forge
jupyter_core 4.7.1 py39h2804cbe_0 conda-forge
jupyterlab_pygments 0.1.2 pyh9f0ad1d_0 conda-forge
jupyterlab_widgets 1.0.0 pyhd8ed1ab_1 conda-forge
kiwisolver 1.3.1 pypi_0 pypi
lazy-object-proxy 1.5.2 pypi_0 pypi
lcms2 2.12 had6a04f_0 conda-forge
libblas 3.9.0 8_openblas conda-forge
libcblas 3.9.0 8_openblas conda-forge
libcxx 11.1.0 h168391b_0 conda-forge
libffi 3.3 h9f76cd9_2 conda-forge
libgfortran 5.0.0.dev0 11_0_1_hf114ba7_20 conda-forge
libgfortran5 11.0.1.dev0 hf114ba7_20 conda-forge
liblapack 3.9.0 8_openblas conda-forge
libopenblas 0.3.12 openmp_h2ecc587_1 conda-forge
libpng 1.6.37 hf7e6567_2 conda-forge
libprotobuf 3.15.6 habe5f53_0 conda-forge
libsodium 1.0.18 h27ca646_1 conda-forge
libtiff 4.2.0 h70663a0_0 conda-forge
libwebp-base 1.2.0 h27ca646_2 conda-forge
littleutils 0.2.2 pypi_0 pypi
llvm-openmp 11.1.0 hb3022d6_0 conda-forge
lz4-c 1.9.3 h9f76cd9_0 conda-forge
markdown 3.3.4 pyhd8ed1ab_0 conda-forge
markupsafe 2.0.1 py39h5161555_0 conda-forge
matplotlib 3.4.2 pypi_0 pypi
matplotlib-inline 0.1.2 pyhd8ed1ab_2 conda-forge
mccabe 0.6.1 pypi_0 pypi
mistune 0.8.4 py39h46acfd9_1003 conda-forge
multidict 5.1.0 py39h46acfd9_1 conda-forge
nb_conda 2.2.1 py39h2804cbe_4 conda-forge
nb_conda_kernels 2.3.1 py39h2804cbe_0 conda-forge
nbclient 0.5.3 pyhd8ed1ab_0 conda-forge
nbconvert 6.0.7 py39h2804cbe_3 conda-forge
nbformat 5.1.3 pyhd8ed1ab_0 conda-forge
ncurses 6.2 h9aa5885_4 conda-forge
nest-asyncio 1.5.1 pyhd8ed1ab_0 conda-forge
ninja 1.10.2 h4d860bb_0 conda-forge
notebook 6.4.0 pyha770c72_0 conda-forge
numpy 1.20.3 py39h1f3b974_1 conda-forge
oauthlib 3.1.0 pypi_0 pypi
ogb 1.3.1 pypi_0 pypi
olefile 0.46 pyh9f0ad1d_1 conda-forge
openssl 1.1.1k h27ca646_0 conda-forge
outdated 0.2.1 pypi_0 pypi
packaging 20.9 pyh44b312d_0 conda-forge
pandas 1.2.3 pypi_0 pypi
pandocfilters 1.4.2 py_1 conda-forge
parso 0.7.1 pypi_0 pypi
pexpect 4.8.0 pyh9f0ad1d_2 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pillow 8.1.2 py39hf007017_0 conda-forge
pip 21.0.1 pyhd8ed1ab_0 conda-forge
pluggy 0.13.1 pypi_0 pypi
prometheus_client 0.10.1 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.18 pyha770c72_0 conda-forge
protobuf 3.17.0 pypi_0 pypi
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pyasn1 0.4.8 py_0 conda-forge
pyasn1-modules 0.2.8 pypi_0 pypi
pycodestyle 2.6.0 pypi_0 pypi
pycparser 2.20 pyh9f0ad1d_2 conda-forge
pydocstyle 6.0.0 pypi_0 pypi
pyflakes 2.2.0 pypi_0 pypi
pygments 2.9.0 pyhd8ed1ab_0 conda-forge
pyjwt 2.1.0 pyhd8ed1ab_0 conda-forge
pylint 2.7.2 pypi_0 pypi
pyopenssl 20.0.1 pyhd8ed1ab_0 conda-forge
pyparsing 2.4.7 pyh9f0ad1d_0 conda-forge
pyrsistent 0.17.3 py39h46acfd9_2 conda-forge
pysocks 1.7.1 py39h2804cbe_3 conda-forge
python 3.9.2 hcbd9b3a_0_cpython conda-forge
python-dateutil 2.8.1 py_0 conda-forge
python-jsonrpc-server 0.4.0 pypi_0 pypi
python-language-server 0.36.2 pypi_0 pypi
python_abi 3.9 1_cp39 conda-forge
pytorch 1.8.0 cpu_py39hff516c6_0 conda-forge
pytz 2021.1 pypi_0 pypi
pyzmq 22.0.3 py39h997613d_1 conda-forge
readline 8.0 hc8eb9b7_2 conda-forge
regex 2021.4.4 py39h5161555_0 conda-forge
requests 2.25.1 pyhd3deb0d_0 conda-forge
requests-oauthlib 1.3.0 pyh9f0ad1d_0 conda-forge
rope 0.18.0 pypi_0 pypi
rsa 4.7.2 pyh44b312d_0 conda-forge
sacremoses 0.0.43 pyh9f0ad1d_0 conda-forge
scikit-learn 0.24.1 py39hb966dd2_0 conda-forge
scipy 1.6.2 py39h5060c3b_0 conda-forge
seaborn 0.11.1 pypi_0 pypi
send2trash 1.5.0 py_0 conda-forge
setuptools 49.6.0 py39h2804cbe_3 conda-forge
six 1.15.0 pyh9f0ad1d_0 conda-forge
sleef 3.5.1 h27ca646_1 conda-forge
snowballstemmer 2.1.0 pypi_0 pypi
sqlite 3.34.0 h6d56c25_0 conda-forge
tensorboard 2.4.1 pyhd8ed1ab_0 conda-forge
tensorboard-data-server 0.6.1 pypi_0 pypi
tensorboard-plugin-wit 1.8.0 pyh44b312d_0 conda-forge
terminado 0.10.0 py39h2804cbe_0 conda-forge
testpath 0.5.0 pyhd8ed1ab_0 conda-forge
threadpoolctl 2.1.0 pyh5ca1d4c_0 conda-forge
tk 8.6.10 hf7e6567_1 conda-forge
tokenizers 0.10.1 py39hda0fb44_0 conda-forge
toml 0.10.2 pypi_0 pypi
torchvision 0.9.0 py39hc5dd9f3_0_cpu conda-forge
tornado 6.1 py39h46acfd9_1 conda-forge
tqdm 4.60.0 pyhd8ed1ab_0 conda-forge
traitlets 5.0.5 py_0 conda-forge
transformers 4.5.0 pyhd8ed1ab_0 conda-forge
typing-extensions 3.7.4.3 0 conda-forge
typing_extensions 3.7.4.3 py_0 conda-forge
tzdata 2021a he74cb21_0 conda-forge
ujson 4.0.2 pypi_0 pypi
urllib3 1.26.4 pyhd8ed1ab_0 conda-forge
wcwidth 0.2.5 pyh9f0ad1d_2 conda-forge
webencodings 0.5.1 py_1 conda-forge
werkzeug 2.0.1 pyhd8ed1ab_0 conda-forge
wheel 0.36.2 pyhd3deb0d_0 conda-forge
widgetsnbextension 3.5.1 py39h2804cbe_4 conda-forge
wilds 1.1.0 pypi_0 pypi
wrapt 1.12.1 pypi_0 pypi
xz 5.2.5 h642e427_1 conda-forge
yapf 0.31.0 pypi_0 pypi
yarl 1.6.3 py39h46acfd9_1 conda-forge
zeromq 4.3.4 h9f76cd9_0 conda-forge
zipp 3.4.1 pyhd8ed1ab_0 conda-forge
zlib 1.2.11 h31e879b_1009 conda-forge
zstd 1.4.9 h5b28eab_0 conda-forge