Using TorchScript to save a model with multiple heads

My goal is to serialize a trained PyTorch model and load it in an environment where the original class defining the neural network is not available. To achieve that, I decided to use TorchScript, since it seems to be the only way to do this.

I have a multi-task model (an nn.Module) built from a body shared by every task (also an nn.Module, just a few linear layers) and a set of linear heads, one per task.
I store the heads in a dictionary of type Dict[int, nn.Module] called _task_head_models, and I wrote a custom forward method in my module class that selects the right head at prediction time:

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        if task_id not in self._task_head_models.keys():
            raise ValueError(
                f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
            )

        return self._task_head_models[task_id](self._model(x))

This works fine until I try to serialize it with TorchScript.
When I call torch.jit.script(mymodule), I get:

Module 'MyModule' has no attribute '_task_head_models' (This attribute exists on the Python module, but we failed to convert Python type: 'dict' to a TorchScript type. Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.)

Something that seems off is that my module contains a Dict, not a dict as the error message says. Setting that aside, it's still unclear why this happens: according to the TorchScript Language Reference, dictionaries seem to be supported.
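The "dict" in the error refers to the runtime type of the attribute: the Dict[int, nn.Module] annotation does not change the fact that the stored object is a plain Python dict. As far as I know, TorchScript supports dictionaries only when the values have a TorchScript type such as Tensor, int or str, and an arbitrary nn.Module is not one of those. A plain dict also bypasses nn.Module's submodule registration, so the heads' weights never show up in parameters() or state_dict(). The sketch below contrasts the two cases. TensorDictModule and PlainDictOfModules are hypothetical names chosen for illustration, and the exact exception type raised by the failing script call may vary between PyTorch versions.

```python
from typing import Dict

import torch
from torch import nn


class TensorDictModule(nn.Module):
    """Hypothetical module whose dict attribute holds Tensors, a TorchScript-compatible type."""

    def __init__(self) -> None:
        super().__init__()
        # TorchScript infers Dict[str, Tensor] from this value.
        self.offsets: Dict[str, torch.Tensor] = {"a": torch.ones(2), "b": torch.zeros(2)}

    def forward(self, x: torch.Tensor, key: str) -> torch.Tensor:
        return x + self.offsets[key]


class PlainDictOfModules(nn.Module):
    """Hypothetical module that keeps a submodule inside a plain Python dict."""

    def __init__(self) -> None:
        super().__init__()
        # A plain dict is not a module container: the Linear layer is not registered
        # as a submodule, and TorchScript cannot convert the attribute's type.
        self.heads = {0: nn.Linear(2, 2)}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.heads[0](x)


scripted = torch.jit.script(TensorDictModule())
print(scripted(torch.zeros(2), "a"))  # tensor([1., 1.])

# The Linear layer's weight and bias are invisible to parameters().
print(len(list(PlainDictOfModules().parameters())))  # 0

try:
    torch.jit.script(PlainDictOfModules())
except Exception as err:
    print("scripting failed:", type(err).__name__)
```

This suggests an nn.ModuleDict (or nn.ModuleList) is the right container for the heads. It registers them as submodules and gives TorchScript a container it knows how to compile.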

I also tried using ModuleDict instead of Dict (changing the key type to str), but that doesn't seem to work either:

Unable to extract string literal index. ModuleDict indexing is only supported with string literals. Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): ...'
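The enumeration the error message points to can also work around the problem without an interface. In TorchScript the loop over a ModuleDict is unrolled at compile time, so each head has a concrete type, and comparing its key with the runtime task_id selects the one to run. The following is a minimal sketch, not a definitive implementation. The class name MultiHeadByEnumeration and the Linear+ReLU body are assumptions for illustration. Only the matching head is evaluated, but the compiled graph contains one branch per head.

```python
import torch
from torch import nn


class MultiHeadByEnumeration(nn.Module):
    """Hypothetical multi-head model that selects a head by enumerating a ModuleDict."""

    def __init__(self, in_features: int, hidden: int, out_features: int, num_tasks: int) -> None:
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        # ModuleDict keys must be strings.
        self.heads = nn.ModuleDict(
            {str(i): nn.Linear(hidden, out_features) for i in range(num_tasks)}
        )

    def forward(self, x: torch.Tensor, task_id: str = "0") -> torch.Tensor:
        features = self.body(x)
        out = torch.empty(0)
        found = False
        # Unrolled by TorchScript: one branch per head, keyed by a constant string.
        for name, head in self.heads.items():
            if name == task_id:
                out = head(features)
                found = True
        if not found:
            raise ValueError("The task id " + task_id + " is not valid.")
        return out


torch.manual_seed(0)
model = MultiHeadByEnumeration(in_features=4, hidden=8, out_features=3, num_tasks=2)
scripted = torch.jit.script(model)
print(scripted(torch.randn(5, 4), "1").shape)  # torch.Size([5, 3])
```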

I've encountered the same problem, and after reading this issue https://github.com/pytorch/pytorch/issues/47496 and seeing this code in the PyTorch tests https://github.com/pytorch/pytorch/blob/2e57b1e8479bb09ea1b0a2de059c1322e6782fdb/test/jit/test_module_containers.py#L431 I figured it out. To index a ModuleDict with a non-literal (runtime) key, you need to define an interface class with the @torch.jit.interface decorator and annotate the indexed value with it. I'm writing it here so others don't have to go through this pain.

Let's say the heads are of type Head:

    import torch
    from torch import nn, Tensor

    class Head(nn.Module):
        ...
        def forward(self, my_x: Tensor) -> Tensor:
            return self.do_something(my_x)

You need to write an interface with exactly the same forward signature (identical argument names!) and decorate it with @torch.jit.interface:

    @torch.jit.interface
    class OutputHeadInterface(nn.Module):
        def forward(self, my_x: Tensor) -> Tensor:
            pass

Then, in the forward of your model, using your code:

    class MyModule(nn.Module):
        ...
        # self._task_head_models is now an nn.ModuleDict with str keys
        def forward(self, x: torch.Tensor, task_id: str = "0") -> torch.Tensor:
            if task_id not in self._task_head_models.keys():
                raise ValueError(
                    f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
                )

            # Annotating with the interface type lets TorchScript index with a non-literal key
            output_head: OutputHeadInterface = self._task_head_models[task_id]
            return output_head.forward(self._model(x))

I've changed task_id to a string, because ModuleDict keys must be strings.
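For reference, here is an end-to-end sketch of the approach above. It scripts the model, saves it, and loads it back with torch.jit.load, which does not need the Python class definitions. Several details are assumptions because the thread elides them: each Head wraps a single nn.Linear, and the shared body is Linear+ReLU. The valid ids are kept in a separate List[str] attribute, _task_ids, a hypothetical name introduced here so the membership check only relies on "in" over a list of strings. The model is saved to an in-memory buffer for convenience, and a file path works the same way.

```python
import io
from typing import List

import torch
from torch import nn, Tensor


class Head(nn.Module):
    """One task head; the single Linear layer is an assumption for illustration."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, my_x: Tensor) -> Tensor:
        return self.linear(my_x)


@torch.jit.interface
class OutputHeadInterface(nn.Module):
    # Must match Head.forward exactly, including the argument name my_x.
    def forward(self, my_x: Tensor) -> Tensor:
        pass


class MyModule(nn.Module):
    def __init__(self, in_features: int, hidden: int, out_features: int, num_tasks: int) -> None:
        super().__init__()
        self._model = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        # ModuleDict keys must be strings.
        self._task_head_models = nn.ModuleDict(
            {str(i): Head(hidden, out_features) for i in range(num_tasks)}
        )
        # Hypothetical helper attribute: the valid ids as a plain List[str].
        self._task_ids: List[str] = list(self._task_head_models.keys())

    def forward(self, x: Tensor, task_id: str = "0") -> Tensor:
        if task_id not in self._task_ids:
            raise ValueError("The task id " + task_id + " is not valid.")
        # The interface annotation is what allows indexing with a runtime key.
        output_head: OutputHeadInterface = self._task_head_models[task_id]
        return output_head.forward(self._model(x))


torch.manual_seed(0)
model = MyModule(in_features=4, hidden=8, out_features=3, num_tasks=2)
scripted = torch.jit.script(model)

# Save to an in-memory buffer here; a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Loading needs only torch, not the Head/MyModule/OutputHeadInterface definitions.
loaded = torch.jit.load(buffer)
x = torch.randn(5, 4)
print(loaded(x, "1").shape)  # torch.Size([5, 3])
```

In the target environment, only the torch.jit.load call (with a path) and the forward call are needed.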