torch.jit.trace hangs

I’m calling torch.jit.trace() on my model and the call never returns. It appears to be stuck in deep recursion inside Python’s difflib, cycling through these frames:

  File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1032, in _fancy_helper
    yield from g
  File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1020, in _fancy_replace
    yield from self._fancy_helper(a, best_i+1, ahi, b, best_j+1, bhi)
  File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1032, in _fancy_helper
    yield from g
  File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1020, in _fancy_replace
    yield from self._fancy_helper(a, best_i+1, ahi, b, best_j+1, bhi)
  File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1032, in _fancy_helper

My model is implemented as a DAG of Python objects that inherit from nn.Module. Each node’s .run() method calls .run() on its child nodes and then applies that node’s own .forward() implementation, as in the example below:

import torch
import torch.nn as nn


class InterleavedSumOp(nn.Module):
  def __init__(self, operand, C_out):
    super(InterleavedSumOp, self).__init__()
    self.C_out = C_out
    self._operands = nn.ModuleList([operand])
    self.output = None  # cached result of this node's forward pass

  def _initialize_parameters(self):
    self._operands[0]._initialize_parameters()

  def run(self, model_inputs):
    # Evaluate the child node, apply this node's forward(), and cache the result.
    if self.output is None:
      operand = self._operands[0].run(model_inputs)
      self.output = self.forward(operand)
    return self.output

  def forward(self, x):
    # Fold the channel dimension into (C // C_out, C_out) groups and sum over the
    # group dimension, leaving C_out output channels.
    x_shape = list(x.shape)
    x_reshaped = torch.reshape(
        x, (x_shape[0], x_shape[1] // self.C_out, self.C_out, x_shape[2], x_shape[3]))
    out = x_reshaped.sum(1, keepdim=False)
    return out

  def to_gpu(self, gpu_id):
    self._operands[0].to_gpu(gpu_id)
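
To make the structure concrete, here is a toy composition of the pattern above and how .run() gets called on it. LeafOp and the shapes are made up for illustration only and are not part of my actual model:

import torch
import torch.nn as nn

# Made-up leaf node, just to illustrate the pattern: its run() forwards the raw model input.
class LeafOp(nn.Module):
  def run(self, model_inputs):
    return self.forward(model_inputs)

  def forward(self, x):
    return x

root = InterleavedSumOp(LeafOp(), C_out=4)
inputs = torch.randn(1, 16, 32, 32)  # channel count (16) must be divisible by C_out (4)
out = root.run(inputs)               # shape (1, 4, 32, 32): output channel c sums input channels c, c+4, c+8, c+12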