I’m calling `torch.jit.trace()` on my model, and it appears to be stuck in an endless loop; the stack trace keeps repeating these frames:
File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1032, in _fancy_helper
yield from g
File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1020, in _fancy_replace
yield from self._fancy_helper(a, best_i+1, ahi, b, best_j+1, bhi)
File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1032, in _fancy_helper
yield from g
File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1020, in _fancy_replace
yield from self._fancy_helper(a, best_i+1, ahi, b, best_j+1, bhi)
File "/data/scratch/karima/anaconda3/envs/mother/lib/python3.7/difflib.py", line 1032, in _fancy_helper
My model is implemented as a DAG of Python objects that inherit from `nn.Module`. Each node calls the `.run()` methods of its child nodes, which in turn call the `.forward()` implementations, as in the example below:
class InterleavedSumOp(nn.Module):
    """DAG node that sums the interleaved channel groups of its one operand.

    The operand's output ``x`` is reshaped from (N, C, *rest) to
    (N, C // C_out, C_out, *rest) and summed over dimension 1, giving
    (N, C_out, *rest). Output channel ``c`` is therefore
    ``x[:, c] + x[:, c + C_out] + x[:, c + 2 * C_out] + ...``.

    Args:
        operand: child node; must provide ``run``, ``_initialize_parameters``
            and ``to_gpu``.
        C_out: number of output channels; presumably must divide the
            operand's channel count (TODO confirm against callers).
    """

    def __init__(self, operand, C_out):
        super().__init__()
        self.C_out = C_out
        self._operands = nn.ModuleList([operand])
        # Result of the most recent pass, reused when this node is reached
        # again (e.g. via another parent in the DAG) with the same inputs.
        self.output = None
        # The exact model_inputs object self.output was computed from.
        self._cached_inputs = None

    def _initialize_parameters(self):
        """Delegate parameter initialization to the operand."""
        self._operands[0]._initialize_parameters()

    def run(self, model_inputs):
        """Evaluate this node for ``model_inputs``, memoized per inputs object.

        Within one pass every node receives the same ``model_inputs`` object,
        so a node shared by several parents is computed only once. A call
        with a different inputs object (a new batch, or the second run that
        ``torch.jit.trace`` performs when ``check_trace=True``) recomputes
        instead of returning a stale tensor from an earlier pass.

        Args:
            model_inputs: whatever the root passes down; forwarded unchanged
                to the operand.

        Returns:
            The result of ``self.forward`` on the operand's output.
        """
        if self.output is None or model_inputs is not self._cached_inputs:
            operand = self._operands[0].run(model_inputs)
            self.output = self.forward(operand)
            self._cached_inputs = model_inputs
        return self.output

    def forward(self, x):
        """Sum the interleaved channel groups of ``x``.

        Args:
            x: tensor of shape (N, C, *rest) with C divisible by ``C_out``.

        Returns:
            Tensor of shape (N, C_out, *rest).
        """
        n_groups = x.shape[1] // self.C_out
        # Trailing dims are passed through, so any rank >= 2 works
        # (the previous version required exactly 4 dims).
        grouped = torch.reshape(x, (x.shape[0], n_groups, self.C_out, *x.shape[2:]))
        return grouped.sum(1, keepdim=False)

    def to_gpu(self, gpu_id):
        """Delegate device placement to the operand (this node has no
        parameters of its own)."""
        self._operands[0].to_gpu(gpu_id)