Using TorchScript to save a model with multiple heads

My goal is to serialize a trained PyTorch model and load it in an environment where the original class defining the neural network is not available. To achieve that, I decided to use TorchScript, since it seems to be the only option.
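
For reference, here is a minimal sketch of the TorchScript round trip being aimed for. It uses a small stand-in nn.Sequential rather than the real multi-task model. torch.jit.script compiles the module, torch.jit.save writes code and weights into a single archive, and torch.jit.load restores it without needing the Python class. An in-memory io.BytesIO buffer stands in for a file path such as "model.pt". That choice is only to keep the example self-contained.

```python
import io

import torch
import torch.nn as nn

# Stand-in model (assumption): any scriptable nn.Module would do here.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

# Compile to TorchScript, then serialize code + weights into one archive.
scripted = torch.jit.script(model)
buffer = io.BytesIO()  # in-memory substitute for a file on disk
torch.jit.save(scripted, buffer)

# Loading only needs torch; the Python class definition is not required.
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
print(same)
```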

I have a multi-task model (an nn.Module) built from a body shared by all tasks (also an nn.Module, just a few linear layers) and a set of linear head models, one per task.
I store the head models in a dictionary of type Dict[int, nn.Module] called _task_head_models, and I wrote a custom forward method in my module class to select the right head at prediction time:

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        if task_id not in self._task_head_models.keys():
            raise ValueError(
                f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
            )

        return self._task_head_models[task_id](self._model(x))
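
Below is a hypothetical, self-contained reconstruction of the setup described above. The class name MultiHeadPlainDict, the layer sizes, and the exact body layers are assumptions, not the original code. It shows three things. The model runs in eager mode. The heads kept in a plain dict are not registered as submodules, so parameters(), state_dict() and .to() do not see them. torch.jit.script rejects the module.

```python
import torch
import torch.nn as nn


class MultiHeadPlainDict(nn.Module):
    """Hypothetical reconstruction: shared body + heads stored in a plain dict."""

    def __init__(self, in_features: int = 4, hidden: int = 8, num_tasks: int = 2):
        super().__init__()
        self._model = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        # Plain dict (conceptually Dict[int, nn.Module]); nn.Module does NOT
        # register its values as submodules.
        self._task_head_models = {i: nn.Linear(hidden, 1) for i in range(num_tasks)}

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        if task_id not in self._task_head_models.keys():
            raise ValueError(f"The task id {task_id} is not valid.")
        return self._task_head_models[task_id](self._model(x))


model = MultiHeadPlainDict()
out = model(torch.randn(3, 4), task_id=1)  # eager mode works fine
print(out.shape)

# Only the body's Linear weight and bias are visible; the heads are hidden.
num_params = len(list(model.parameters()))
print(num_params)

try:
    torch.jit.script(model)
    scripted_ok = True
except Exception as err:  # TorchScript cannot convert the dict-of-modules attribute
    scripted_ok = False
    print(type(err).__name__)
```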

This works fine until I try to serialize it with TorchScript.
When I run torch.jit.script(mymodule), I get:

Module 'MyModule' has no attribute '_task_head_models' (This attribute exists on the Python module, but we failed to convert Python type: 'dict' to a TorchScript type. Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.)

Something that seems off is that my module contains a Dict, not a dict as the error message says. Setting that aside for a second, it’s still unclear to me why this is happening: dictionaries seem to be supported according to the language reference (TorchScript Language Reference - PyTorch - W3cubDocs).
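
On the Dict vs dict point: typing.Dict[int, nn.Module] is only a type hint. The object actually stored on the module at runtime is an ordinary built-in dict, so the error message reports the real runtime type. The small pure-Python sketch below illustrates this. The Holder class is hypothetical, and a str value type is used to avoid depending on torch. As far as I understand the TorchScript type system, its Dict support covers values of TorchScript types such as tensors, ints or strings. nn.Module instances are not among them, which would explain why the conversion fails even though dictionaries in general are supported.

```python
from typing import Dict, get_type_hints


class Holder:
    # Class-level annotation, analogous to Dict[int, nn.Module] in the question.
    heads: Dict[int, str]

    def __init__(self) -> None:
        self.heads = {0: "head0", 1: "head1"}


h = Holder()
print(type(h.heads))                    # <class 'dict'>
print(get_type_hints(Holder)["heads"])  # typing.Dict[int, str]
```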

I also tried using ModuleDict instead of Dict (changing the key type to str), but that doesn’t seem to work either:

Unable to extract string literal index. ModuleDict indexing is only supported with string literals. Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): ...':
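
One possible workaround follows the hint in the error message itself. Keep the heads in an nn.ModuleDict, but select the head by iterating over items() and comparing keys instead of indexing with a runtime string. The sketch below is an illustration under assumptions, not the original model. The class name MultiHeadModel, the layer sizes, and the str(task_id) key scheme are hypothetical. It also checks that the scripted model survives a save/load round trip.

```python
import io

import torch
import torch.nn as nn


class MultiHeadModel(nn.Module):
    """Hypothetical multi-task model: shared body plus one linear head per task."""

    def __init__(self, in_features: int = 4, hidden: int = 8, num_tasks: int = 2):
        super().__init__()
        self._model = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        # ModuleDict registers the heads (parameters(), state_dict(), .to()),
        # but its keys must be strings, hence str(i).
        self._task_head_models = nn.ModuleDict(
            {str(i): nn.Linear(hidden, 1) for i in range(num_tasks)}
        )

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        features = self._model(x)
        key = str(task_id)
        out = features  # placeholder; overwritten by the selected head's output
        found = False
        # Iterating a ModuleDict is allowed in TorchScript; compare keys
        # instead of indexing with a non-literal string.
        for name, head in self._task_head_models.items():
            if name == key:
                out = head(features)
                found = True
        if not found:
            raise ValueError("The task id " + key + " is not valid.")
        return out


model = MultiHeadModel().eval()
scripted = torch.jit.script(model)

buffer = io.BytesIO()  # stands in for a file path
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)  # no MultiHeadModel class needed here

x = torch.randn(3, 4)
with torch.no_grad():
    eager_out = model(x, 1)
    loaded_out = loaded(x, 1)
print(torch.allclose(eager_out, loaded_out))

try:
    loaded(x, 5)  # no head registered under "5"
    rejected = False
except Exception:  # catch broadly: the type raised from TorchScript may not be ValueError
    rejected = True
print(rejected)
```

Because the loop over the ModuleDict is unrolled when scripting, every head is compiled into the graph but only the matching one is executed. If integer ids are preferred, an nn.ModuleList iterated with enumerate() and compared against task_id should work along the same lines.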