Profiler crashing: "RuntimeError: can't export a trace that didn't finish running"

I have created a neural network that is, for some reason, running extremely slowly (especially the backward pass, which takes ~40x as long as the forward pass), so I decided to try using the profiler on it. I’m currently using it like this, which I have basically taken straight from the profiler documentation:

    with profiler.profile(record_shapes=True) as prof:
        with profiler.record_function("model_inference"):
            node_vec = model(input=coords_init, xn_attr=node_attr)
        # note: this print is still inside the profiler.profile block
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

However, when I use this I get the following error:

    Traceback (most recent call last):
      File "/home/tue/PycharmProjects/GraphNetworks/src/protein_utils.py", line 41, in use_proteinmodel
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
      File "/home/tue/PycharmProjects/GraphNetworks/e3nn_cpu/lib/python3.8/site-packages/torch/autograd/profiler.py", line 552, in key_averages
        self._check_finish()
      File "/home/tue/PycharmProjects/GraphNetworks/e3nn_cpu/lib/python3.8/site-packages/torch/autograd/profiler.py", line 525, in _check_finish
        raise RuntimeError("can't export a trace that didn't finish running")
    RuntimeError: can't export a trace that didn't finish running

The model runs fine (if a bit slowly) when I don’t use the profiler on it. The error seems to suggest that the model hasn’t finished running, which I don’t really understand, and unfortunately I haven’t been able to find anyone else who has encountered this error.
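The "trace" in the error message refers to the profiler's recording, not to the model's computation. As a rough illustration, here is a minimal pure-Python sketch using a hypothetical ToyProfiler class, loosely modeled on the _check_finish call shown in the traceback above (it is not PyTorch's actual implementation). Its results only exist once the with block has exited, so asking for them inside the block raises the same kind of error:

```python
class ToyProfiler:
    """Hypothetical stand-in that mimics the profiler's 'finished' check."""

    def __init__(self):
        self._events = []
        self._results = None  # only set when the with block exits

    def __enter__(self):
        return self

    def record(self, name):
        self._events.append(name)

    def __exit__(self, exc_type, exc, tb):
        # Results are finalized here, when the with block exits.
        self._results = list(self._events)
        return False  # do not suppress exceptions

    def key_averages(self):
        if self._results is None:
            raise RuntimeError("can't export a trace that didn't finish running")
        return self._results


with ToyProfiler() as prof:
    prof.record("model_inference")
    try:
        prof.key_averages()  # still inside the block: trace not finished yet
    except RuntimeError as err:
        print("inside:", err)

print("outside:", prof.key_averages())  # block has exited: results available
```

The same reasoning applies to the real profiler: the model itself finished long before the print, but the profile context had not.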

I’m currently running this on my laptop using CPU, with the following:

    (e3nn_cpu) tue@tue-laptop:~/PycharmProjects/GraphNetworks$ pip list
    Package               Version
    --------------------- ---------
    ase                   3.21.1
    certifi               2020.12.5
    chardet               4.0.0
    cycler                0.10.0
    decorator             4.4.2
    e3nn                  0.2.7
    googledrivedownloader 0.4
    h5py                  3.2.1
    idna                  2.10
    isodate               0.6.0
    Jinja2                2.11.3
    joblib                1.0.1
    kiwisolver            1.3.1
    llvmlite              0.36.0
    MarkupSafe            1.1.1
    matplotlib            3.4.1
    mpmath                1.2.1
    networkx              2.5.1
    numba                 0.53.1
    numpy                 1.20.2
    pandas                1.2.4
    Pillow                8.2.0
    pip                   21.0.1
    pyparsing             2.4.7
    python-dateutil       2.8.1
    python-louvain        0.15
    pytz                  2021.1
    rdflib                5.0.0
    requests              2.25.1
    scikit-learn          0.24.1
    scipy                 1.6.2
    setuptools            56.0.0
    six                   1.15.0
    sympy                 1.8
    threadpoolctl         2.1.0
    torch                 1.8.1+cpu
    torch-cluster         1.5.9
    torch-geometric       1.7.0
    torch-scatter         2.0.6
    torch-sparse          0.6.9
    torch-spline-conv     1.2.1
    torchaudio            0.8.1
    torchvision           0.9.1+cpu
    tqdm                  4.60.0
    typing-extensions     3.7.4.3
    urllib3               1.26.4

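I figured out the issue: the print statement needs to be outside both with statements. The profiler only finalizes its trace when the profiler.profile block exits, so calling prof.key_averages() inside that block raises this error. The corrected version looks like this: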

    with profiler.profile(record_shapes=True) as prof:
        with profiler.record_function("model_inference"):
            node_vec = model(input=coords_init, xn_attr=node_attr)
    # the profile block has exited here, so the trace is finished
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
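A self-contained sketch of the corrected pattern follows. It assumes a small hypothetical nn.Sequential stand-in for the real model (model, coords_init and node_attr from the post are not reproduced), and, since the backward pass was the slow part, it also wraps the backward call in its own record_function label. The label "model_backward" and the sum() placeholder loss are illustrative choices, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.autograd.profiler as profiler


def profile_forward_backward(model, x):
    """Profile one forward and one backward pass; return the finished profiler."""
    with profiler.profile(record_shapes=True) as prof:
        with profiler.record_function("model_inference"):
            out = model(x)
        with profiler.record_function("model_backward"):
            # backward() needs a scalar; sum() is a placeholder loss
            out.sum().backward()
    # The profile block has exited, so key_averages() can be called safely.
    return prof


if __name__ == "__main__":
    # Hypothetical stand-in for the real model and inputs.
    model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
    x = torch.randn(32, 16)
    prof = profile_forward_backward(model, x)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

Returning the profiler from a function makes the ordering hard to get wrong: by the time the caller can touch prof, the with block has already exited.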

It would be nice if this were mentioned in the documentation/tutorial.