RuntimeError when running a loaded module a second time

Hello,

I’ve saved a model using torch.jit.trace and then loaded it using torch.jit.load.
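For reference, the workflow is roughly the following. This is a minimal sketch: `TinyNet`, its layer sizes, and the file location are hypothetical stand-ins, since the real model isn't shown here, and a model this small will not necessarily reproduce the error.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyNet().eval()
example = torch.randn(1, 4)

# Trace with an example input, then serialize the resulting TorchScript module.
traced = torch.jit.trace(model, example)
path = os.path.join(tempfile.mkdtemp(), "traced_model.pt")  # hypothetical location
torch.jit.save(traced, path)

# Reload and call it more than once; in my case the second call is the one that fails.
loaded = torch.jit.load(path)
with torch.no_grad():
    first = loaded(example)
    second = loaded(example)
```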

The first time I run the model, it produces the correct output, but on the second and every subsequent call it throws this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-25-946d12987c4d> in <module>
----> 1 out = model(inputs)

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

RuntimeError: _Map_base::at

Any idea what could be causing it?

I’ve reproduced this with torch 1.8.0 and torch 1.10.0, on both CPU and GPU.

Since the second call is apparently where the JIT runs its optimization, and that seems to be where it fails, is there a way to disable this optimization?
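One thing that may be worth trying is `torch.jit.optimized_execution(False)`, a context manager in `torch.jit` that turns off the graph executor's optimization passes for calls made inside it. The sketch below uses a hypothetical `TinyNet` in place of the real model and traces it in memory instead of loading it from disk. I haven't confirmed that this avoids the `_Map_base::at` error, only that it keeps the JIT from switching to its optimized plan.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


x = torch.randn(1, 4)
# Stands in for torch.jit.load(...) on the saved file.
loaded = torch.jit.trace(TinyNet().eval(), x)

outputs = []
# Inside this block the graph executor skips its optimization passes,
# so repeated calls keep running the unoptimized graph.
with torch.jit.optimized_execution(False), torch.no_grad():
    for _ in range(3):
        outputs.append(loaded(x))
```

If that changes the behaviour, a global alternative that is sometimes suggested is `torch._C._jit_set_profiling_executor(False)`, which falls back to the legacy executor. It is a private, underscore-prefixed function, so it may change or disappear between releases.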