Hi folks,
I have a model that contains a ModuleDict, which requires its keys to be strings. However, channel_id is an int. So inside the forward() function I call str(channel_id) to convert it to a string and use the result as the ModuleDict key. This works fine.
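For reference, here is a minimal sketch of the pattern described above. The class name ChannelModel, the attribute heads, the per-channel nn.Linear layers and the single input x (standing in for the real premise/hypothesis inputs) are all hypothetical, since the actual model code isn't shown:

```python
import torch
import torch.nn as nn


class ChannelModel(nn.Module):
    """Hypothetical model: one head per channel, stored in an nn.ModuleDict."""

    def __init__(self, channel_ids, in_features=4, out_features=2):
        super().__init__()
        # ModuleDict keys must be strings, so the int channel ids are converted here
        self.heads = nn.ModuleDict(
            {str(c): nn.Linear(in_features, out_features) for c in channel_ids}
        )

    def forward(self, x: torch.Tensor, channel_id: int) -> torch.Tensor:
        # channel_id arrives as a plain Python int and is turned into the string key
        return self.heads[str(channel_id)](x)


torch.manual_seed(0)
model = ChannelModel(channel_ids=[0, 1])
x = torch.randn(3, 4)
out0 = model(x, 0)
out1 = model(x, 1)
print(out0.shape)  # torch.Size([3, 2])
```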
The problem arises when I convert the model for use from C++ with torch.jit.trace: the example inputs MUST be Tensors, and Tensors do NOT support a string data type. So I pass channel_id as an int64 Tensor instead. However, that channel_id is a Tensor holding a single int64 value, while the channel_id argument of the model's forward() function is a plain int, not a Tensor.
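One tempting shortcut is a thin wrapper that accepts channel_id as a Tensor and converts it back to an int before calling the unchanged model. The sketch below (hypothetical ChannelModel and TensorChannelWrapper, single input x) shows why that does not help under tracing: int(channel_id) becomes a Python number, so the tracer records the ModuleDict lookup for whatever channel was used in the example input and ignores the channel_id passed later.

```python
import warnings

import torch
import torch.nn as nn


class ChannelModel(nn.Module):
    """Hypothetical model with one head per channel, keyed by str(channel_id)."""

    def __init__(self, channel_ids, in_features=4, out_features=2):
        super().__init__()
        self.heads = nn.ModuleDict(
            {str(c): nn.Linear(in_features, out_features) for c in channel_ids}
        )

    def forward(self, x: torch.Tensor, channel_id: int) -> torch.Tensor:
        return self.heads[str(channel_id)](x)


class TensorChannelWrapper(nn.Module):
    """Hypothetical wrapper: takes channel_id as a tensor so trace() accepts it."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor, channel_id: torch.Tensor) -> torch.Tensor:
        # int() converts the tensor to a Python number; the tracer treats the
        # result as a constant and emits a TracerWarning about it
        return self.model(x, int(channel_id))


torch.manual_seed(0)
model = ChannelModel(channel_ids=[0, 1])
x = torch.randn(3, 4)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    traced = torch.jit.trace(TensorChannelWrapper(model), (x, torch.tensor([0])))

# Asking for channel 1 still runs head "0": the lookup was frozen at trace time
out = traced(x, torch.tensor([1]))
print(torch.allclose(out, model(x, 0)))  # True
```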
Of course, I could change the model's forward() function to accept channel_id as a Tensor instead of a single int64 value just so the model can be traced, but that would require a lot of changes.
Is there a better way to keep the existing model code and still be able to trace it?
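```python
import torch

# channel_id is passed as a 1-element int64 tensor because the example inputs to trace() must be tensors
channel_id = torch.ones(1, dtype=torch.int64)
traced_script_module = torch.jit.trace(model, (premise, premise_length, hypotheses, hypotheses_length, channel_id))
```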
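A possible alternative, swapping tracing for scripting, is sketched below under assumptions: torch.jit.script compiles forward() with its type annotations, so channel_id can stay a plain int. Scripting does not accept a ModuleDict lookup with a runtime string key here, so this sketch instead loops over heads.items(). TorchScript unrolls that loop and compares keys at run time. ChannelModel, heads and the single input x are hypothetical. The real forward() (with premise_length, hypotheses_length, etc.) may need extra type annotations or other changes before it scripts cleanly.

```python
import torch
import torch.nn as nn


class ChannelModel(nn.Module):
    """Hypothetical model; forward() keeps its plain int channel_id argument."""

    def __init__(self, channel_ids, in_features=4, out_features=2):
        super().__init__()
        self.heads = nn.ModuleDict(
            {str(c): nn.Linear(in_features, out_features) for c in channel_ids}
        )

    def forward(self, x: torch.Tensor, channel_id: int) -> torch.Tensor:
        key = str(channel_id)
        out = x  # placeholder so `out` already has a Tensor type before the loop
        found = False
        # TorchScript unrolls iteration over a ModuleDict, so the key is matched
        # at run time instead of being frozen as it would be under tracing
        for name, head in self.heads.items():
            if name == key:
                out = head(x)
                found = True
        if not found:
            raise ValueError("unknown channel_id")
        return out


torch.manual_seed(0)
model = ChannelModel(channel_ids=[0, 1])
scripted = torch.jit.script(model)
x = torch.randn(3, 4)
print(torch.allclose(scripted(x, 1), model.heads["1"](x)))  # True
```

The scripted module can then be saved with scripted.save(...) and loaded in C++ with torch::jit::load. Because forward() keeps its int parameter, the C++ caller would pass the channel id as an integer rather than a tensor.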