I’m tracing the mapping network that sits between two generators with torch.jit.trace_module.
Code example:
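# Example inputs for the mapping network: (feature map, inst_data)
mapping_input = torch.rand(1, 64, 504, 378).to("cuda"), inst_data
# Example inputs for the encoder/decoder methods (not used in the mapping trace below)
inputs = {'forward': encoder_forward_input, 'forward_encoder': encoder_forward_input, 'forward_decoder': decoder_forward_input}
# Trace both mapping methods with the same example inputs
inputs_mapping = {'forward': mapping_input, 'inference_forward': mapping_input}
# strict=False lets the tracer record mutable containers (list/dict)
traced_model_cuda = torch.jit.trace_module(self.mapping_net, inputs_mapping, strict=False)
# Save the traced module and reload it in place of the eager mapping_net
torch.jit.save(traced_model_cuda, "scratch_mapping_net_traced_model_forward_strict_off.pt")
self.mapping_net = torch.jit.load("scratch_mapping_net_traced_model_forward_strict_off.pt")
self.mapping_net.cuda(self.opt.gpu_ids[0])
# Run the traced inference_forward on a new image
label_feat_map = self.mapping_net.inference_forward(label_feat.detach(), inst_data)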
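For reference, here is a minimal self-contained sketch of the same trace_module, save, load round trip on a hypothetical two-method module. TinyMapping and its layer sizes are made up for illustration, and an in-memory io.BytesIO buffer stands in for the .pt file. The dict passed to trace_module maps each method name to a tuple of example inputs, and the reloaded ScriptModule exposes inference_forward as an ordinary method.

```python
import io

import torch
import torch.nn as nn


class TinyMapping(nn.Module):
    """Hypothetical stand-in for mapping_net with two traced entry points."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, inst: torch.Tensor) -> torch.Tensor:
        return self.conv(x) * inst

    def inference_forward(self, x: torch.Tensor, inst: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x)) * inst


net = TinyMapping().eval()
example = (torch.rand(1, 8, 16, 12), torch.ones(1, 1, 16, 12))

# One entry per method to trace; each value is that method's example inputs
traced = torch.jit.trace_module(net, {"forward": example, "inference_forward": example})

buffer = io.BytesIO()  # in-memory stand-in for the .pt file on disk
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The non-forward method survives serialization and can be called directly
out = loaded.inference_forward(*example)
print(out.shape)
```

Because nothing in this toy graph depends on the values inside the inputs, the reloaded methods also give correct results for new inputs of the same size.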
I get the following error when I run the traced graph and weights on a different image of the same size:
torch.Size([1, 1, 2016, 1512])
Skip HR_input_2016.png due to an error:
The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/models/NonLocal_feature_mapping_model.py", line 231, in inference_forward
    _120 = torch.slice(x_unfold, 0, 0, 9223372036854775807, 1)
    _121 = torch.slice(_120, 1, 0, 9223372036854775807, 1)
    _122 = torch.view(composed_unfold, [32768, 154])
           ~~~~~~~~~~ <--- HERE
    mask_index5 = torch.to(mask_index, dtype=4, layout=0, device=torch.device("cuda:0"), pin_memory=None, non_blocking=False, copy=False, memory_format=None)
    _123 = annotate(List[Optional[Tensor]], [None, None, mask_index5])
Traceback of TorchScript, original code (most recent call last):
.../Global/models/networks.py(774): inference_forward
.../Global/models/NonLocal_feature_mapping_model.py(197): inference_forward
.../anaconda3/lib/python3.8/site-packages/torch/jit/_trace.py(934): trace_module
.../Global/models/mapping_model.py(415): inference
test.py(190): <module>
RuntimeError: shape '[32768, 154]' is invalid for input of size 1474560
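The numbers in this error can be checked directly: the view asks for 32768 × 154 = 5,046,272 elements, while the tensor actually holds 1,474,560 = 32768 × 45 elements. A small sketch of that check (the helper name inferred_last_dim is hypothetical):

```python
def inferred_last_dim(numel: int, rows: int) -> int:
    """Size the second dimension needs so that a (rows, ?) view holds numel elements."""
    if numel % rows != 0:
        raise ValueError(f"{numel} elements cannot be split evenly into {rows} rows")
    return numel // rows


# Values copied from the RuntimeError in the traceback
baked_rows, baked_cols = 32768, 154  # shape hard-coded in the serialized trace
actual_numel = 1474560               # element count of composed_unfold at run time

print(baked_rows * baked_cols)                      # 5046272 elements expected by the trace
print(inferred_last_dim(actual_numel, baked_rows))  # 45 columns in the new input
```

So the 32768 rows match, and only the second dimension differs: 154 for the image used during tracing versus 45 for the new one.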
Source of the inference_forward method: link
However, the traced module runs fine on the image it was traced with.
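One plausible cause, offered as a hypothesis rather than a diagnosis: torch.jit.trace records only the operations executed for the example input, so any size that Python code turns into a plain int (for example with int(), .item() or len()) is frozen into the graph as a constant. The 154 in the baked view looks like such a count, presumably the number of positions selected through mask_index. That number depends on the mask values, not on the image size, which would explain why a same-sized image can still fail. The sketch below reproduces only the mechanism on a hypothetical module (ToyMaskedView is not from the original code). It also shows that torch.jit.script, which compiles the Python code instead of recording a single run, recomputes the count at run time.

```python
import torch
import torch.nn as nn


class ToyMaskedView(nn.Module):
    """Hypothetical module whose output width depends on the mask *values*."""

    def forward(self, feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # feats: (C, N) unfolded features; mask: (N,) bool, True marks a selected patch
        k = int(mask.sum())        # plain Python int: tracing freezes it as a constant
        selected = feats[:, mask]  # (C, k) gather of the selected columns
        # Analogous to the baked torch.view(composed_unfold, [32768, 154])
        return selected.reshape([feats.shape[0], k])


feats = torch.arange(20.0).reshape(4, 5)
mask_traced = torch.tensor([True, False, True, True, False])  # 3 selected patches
mask_new = torch.tensor([True, False, False, True, False])    # 2 selected patches, same size

traced = torch.jit.trace(ToyMaskedView(), (feats, mask_traced))  # emits a TracerWarning
scripted = torch.jit.script(ToyMaskedView())

print(traced(feats, mask_traced).shape)  # same mask as at trace time: works
print(scripted(feats, mask_new).shape)   # scripting recomputes k for the new mask
try:
    traced(feats, mask_new)              # still reshapes to the baked [4, 3]
except Exception as err:                 # raised from the TorchScript interpreter
    print("traced call failed:", type(err).__name__)
```

Whether this matches the real inference_forward depends on how it computes that count; the sketch only illustrates why a trace can succeed on one input and fail on another input with identical shape.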