How to ensemble multiple jit models?

I’m trying to deploy an ensemble model that consists of 5 base models. Sources like this and this provide great examples of building ensemble PyTorch models, but is there a way for us to either:

  • JIT-compile (i.e., torch.jit.script and torch.jit.optimize_for_inference) an ensemble model that is based on vmap; or
  • ensemble multiple JIT models?

I tried the former, which led to an error saying “Compiled functions can’t take variable number of arguments or use keyword-only arguments with defaults”. The latter approach led to an error saying attribute “” does not exist in these jit scripted models.
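The first error comes from TorchScript's frontend, which refuses to compile any function whose signature uses *args, **kwargs, or keyword-only arguments with defaults. The vmap code itself isn't shown, so the following is only a minimal sketch, not a fix for that exact code. It reproduces the message with a hypothetical `forward_varargs` function. It then shows one possible workaround: a hypothetical `LoopEnsemble` that calls each member in a plain loop over an `nn.ModuleList`. That module scripts cleanly and can be passed to `torch.jit.optimize_for_inference`. The five `nn.Linear(4, 2)` members are placeholders for real base models.

```python
from typing import List

import torch
import torch.nn as nn


# Hypothetical function: a *args signature is rejected by TorchScript,
# producing the "variable number of arguments" error quoted above.
def forward_varargs(*args: torch.Tensor) -> torch.Tensor:
    return args[0]


try:
    torch.jit.script(forward_varargs)
    vararg_error = ""
except Exception as e:  # expected: a TorchScript NotSupportedError
    vararg_error = str(e)


class LoopEnsemble(nn.Module):
    """Hypothetical ensemble that averages members with a plain loop (no vmap)."""

    def __init__(self, members: List[nn.Module]):
        super().__init__()
        self.members = nn.ModuleList(members)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs: List[torch.Tensor] = []
        for member in self.members:  # TorchScript unrolls this loop
            outputs.append(member(x))
        return torch.stack(outputs).mean(dim=0)


torch.manual_seed(0)
# Placeholder base models, for illustration only.
ensemble = LoopEnsemble([nn.Linear(4, 2) for _ in range(5)]).eval()
scripted = torch.jit.script(ensemble)
# optimize_for_inference freezes the scripted module before optimizing it.
optimized = torch.jit.optimize_for_inference(scripted)

x = torch.randn(3, 4)
with torch.no_grad():
    eager_out = ensemble(x)
    opt_out = optimized(x)
```

The trade-off is that, unlike vmap, the loop runs the members one after another, so it gives up vectorization in exchange for a graph TorchScript can compile.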

Any help is appreciated!

Hi jerrybai1995, could you provide us with the source code which you tried to use for the ensemble? What models are you trying to ensemble? Are they already JIT scripted, are they instances of nn.Module, or something else?

I personally ensemble JIT scripted models in the following way:

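import torch
from typing import List

class Ensemble(torch.nn.Module):
    """Ensemble model, wrapping multiple sub-models in separate asynchronous TorchScript tasks.
    Allows for parallel computation of the ensemble forward passes."""

    def __init__(self, models: List[torch.nn.Module]):
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Launch one task per sub-model (parallel once the ensemble is scripted)
        futures: List[torch.jit.Future[torch.Tensor]] = []
        for model in self.models:
            futures.append(torch.jit.fork(model, x))
        # Wait for every task, then average the sub-model outputs
        results: List[torch.Tensor] = []
        for fut in futures:
            results.append(torch.jit.wait(fut))
        return torch.mean(torch.stack(results), dim=0)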

and I instantiate the ensemble with already JIT scripted models, e.g.,

    import torch

    models = []
    for m in ['', '', '']:
        model = torch.jit.load(m)  # load each already-scripted base model
        models.append(model)

    ensemble = Ensemble(models)

    # Script and save ensemble
    scripted = torch.jit.script(ensemble)
    torch.jit.save(scripted, '')