My goal is to serialize a trained PyTorch model and load it in an environment where the original class defining the neural network is not available. To achieve that, I decided to use TorchScript, since it seems to be the only possible way.
I have a multi-task model (type `nn.Module`) built from a body common to every task (also an `nn.Module`, just a few linear layers) and a set of linear head models, one per task. I store the head models in a dictionary `Dict[int, nn.Module]` called `_task_head_models`, and I created an ad-hoc forward method in my module class to select the right head at prediction time:
```python
def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
    if task_id not in self._task_head_models.keys():
        raise ValueError(
            f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
        )
    return self._task_head_models[task_id](self._model(x))
```
This works fine until I try to serialize it with TorchScript. When I run `torch.jit.script(mymodule)`, I get:
```
Module 'MyModule' has no attribute '_task_head_models' (This attribute exists on the Python module, but we failed to convert Python type: 'dict' to a TorchScript type. Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.)
```
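For what it's worth, part of the problem may be that a plain Python dict attribute does not register its values as submodules at all, so PyTorch (and TorchScript) never sees them as part of the module tree. A minimal sketch illustrating the difference (class and attribute names here are made up for the demo):

```python
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # plain dict: the Linear inside is NOT registered as a submodule
        self._heads = {0: nn.Linear(2, 1)}
        # nn.ModuleDict: the Linear inside IS registered as a submodule
        self._heads_md = nn.ModuleDict({"0": nn.Linear(2, 1)})

d = Demo()
# only the ModuleDict entry appears among the named submodules
names = [name for name, _ in d.named_modules()]
```

Here `names` contains `_heads_md` and `_heads_md.0`, but nothing from the plain dict, which also means the plain dict's parameters are invisible to `state_dict()` and to the optimizer.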
Something that seems off is that my module contains a `Dict`, not a `dict` as mentioned in the error message. Setting that aside, it's still unclear why this is happening, since dictionaries seem to be supported in the language reference: TorchScript Language Reference - PyTorch - W3cubDocs
I also tried using `ModuleDict` instead of `Dict` (changing the key type to `str`), but that doesn't seem to work either:

```
Unable to extract string literal index. ModuleDict indexing is only supported with string literals. Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): ...'
```
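Following the hint in that error message, one workaround that scripts for me is to keep the `ModuleDict` but look the head up by enumeration instead of indexing, since TorchScript only allows `ModuleDict` indexing with string literals. A sketch, with the body and head sizes invented for the example (only `_model` and `_task_head_models` come from my actual code):

```python
import torch
import torch.nn as nn

class MultiTaskModel(nn.Module):
    def __init__(self, n_tasks: int = 2, dim: int = 4):
        super().__init__()
        self._model = nn.Linear(dim, dim)  # shared body (stand-in)
        self._task_head_models = nn.ModuleDict(
            {str(i): nn.Linear(dim, 1) for i in range(n_tasks)}
        )

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        # ModuleDict indexing needs a string literal in TorchScript,
        # but enumeration is supported, so find the head by iterating.
        key = str(task_id)
        for name, head in self._task_head_models.items():
            if name == key:
                return head(self._model(x))
        raise ValueError("The task id " + str(task_id) + " is not valid.")

scripted = torch.jit.script(MultiTaskModel())
out = scripted(torch.randn(3, 4), task_id=1)
```

The loop over the `ModuleDict` is unrolled at compile time, so this stays scriptable even though dynamic string indexing is not allowed.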