torch.export does not support specifying partial outputs

I have a two-tower model with outputs A and B. During training I need to train the whole model, but during inference I only want to activate one part of it.

For example, for request 1 I need to run only the subgraph for output A, and for request 2 only the subgraph for output B.

TensorFlow allows a model to have multiple graph signatures, so a single exported model can expose several subgraphs and the runtime selects the right one per request. However, PyTorch does not have such a capability yet.

One alternative is to export each part separately, but that is non-trivial and duplicates the shared parameters. Is there a better method?

I'm not sure exactly which capability you are looking for. In PyTorch you can write your forward pass however you like and activate only the part of the model you need. Could you describe your use case and the functionality you need in a bit more detail?

Sure, consider the following example:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.ln = nn.Linear(20, 10)  # shared layer
        # output tower A
        self.ln_x1 = nn.Linear(10, 100)
        self.ln_x2 = nn.Linear(100, 1)

        # output tower B
        self.ln_x3 = nn.Linear(10, 1)

    def forward(self, x: torch.Tensor, signature: int):
        x = self.ln(x)
        if signature == 1:  # tower A
            x = self.ln_x1(x)
            x = self.ln_x2(x)
        elif signature == 2:  # tower B
            x = self.ln_x3(x)
        return x
```

Now I want to export and AOT-compile the model as a single artifact, and at inference time the value of signature decides which path is taken. I can think of two approaches:

  1. Pass signature in as a tensor, but then the conditional branch depends on tensor data, which causes a graph break and might make AOT compilation fail.
  2. Export the two towers separately, but then the parameters of self.ln are duplicated in both models.

In TensorFlow, you can export a single SavedModel with two signatures, one producing the tower A output and one producing the tower B output. The runtime then handles the subgraph for each signature.

Is there an equivalent solution in PyTorch?