[jit] Performing Graph Surgery on RecursiveScriptModules

Dear community,
I am working on deploying PyTorch models. Our customers provide us with models that have been torch.jit.script(...)-ed and saved as *.pt files; this is the common model exchange interface they have defined for us.

We are now trying to extract subgraphs from the full model graph. I have already managed to carve out single layers from a model, e.g.:

```python
import torch

scripted_model = torch.jit.load('sc_model.pt')
for name, module in scripted_model.named_modules():
    # Skip hierarchical (container) modules: their children are also
    # yielded by named_modules() as standalone leaf layers
    if sum(1 for _ in module.named_children()) != 0:
        continue

    # Save the leaf layers one by one
    module.save(f"sc_model_{name}.pt")
```
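One caveat when picking "subsequent" layers: named_modules() and named_children() typically follow the order in which submodules were registered in `__init__`, which need not be the order in which `forward()` calls them. A minimal sketch, assuming the layers of interest are direct children of the loaded module, reads the actual call order from the module's TorchScript graph by walking its `prim::CallMethod` nodes. `call_order` and `Toy` are hypothetical stand-ins, and the graph accessors used (`nodes()`, `kind()`, `inputs()`, `node()`, `s()`) are lightly documented TorchScript IR bindings that may change between releases.

```python
import io

import torch
import torch.nn as nn


def call_order(scripted_model):
    """Hypothetical helper: names of the direct child modules in the order
    forward() calls them, read from the non-inlined TorchScript graph."""
    order = []
    # graph.nodes() visits top-level nodes only; calls nested inside
    # if/loop blocks are not collected by this sketch.
    for node in scripted_model.graph.nodes():
        if node.kind() == "prim::CallMethod":
            # Input 0 is the module being called; for a child module it is
            # produced by prim::GetAttr[name="<child>"] on self.
            producer = list(node.inputs())[0].node()
            if producer.kind() == "prim::GetAttr":
                order.append(producer.s("name"))
    return order


# Toy stand-in for a customer's model: "head" is registered first but runs last.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(8, 3)
        self.act = nn.ReLU()
        self.body = nn.Linear(4, 8)

    def forward(self, x):
        return self.head(self.act(self.body(x)))


# Round-trip through an in-memory buffer, standing in for torch.jit.load('sc_model.pt').
buf = io.BytesIO()
torch.jit.save(torch.jit.script(Toy()), buf)
buf.seek(0)
scripted_model = torch.jit.load(buf)

print([name for name, _ in scripted_model.named_children()])  # order from named_children()
print(call_order(scripted_model))  # order in which forward() calls the children
```

For nested models, the same walk can be repeated on a child's own `.graph` to see the order inside that block.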

The question: is there a way to carve out multiple consecutive layers into a single *.pt file? Is this even possible?

Of course, the extracted subnetwork must remain a valid model, so that we can call it for inference just as we would invoke the full model.
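One possible approach, sketched under strong assumptions: if the chosen layers run back to back in the original `forward()`, each taking and returning a single tensor with nothing else (skip connections, reshapes, functional ops such as `F.relu`) in between, the already-scripted submodules can be wrapped in an `nn.Sequential` and passed to `torch.jit.script`, which reuses existing ScriptModule children as-is and only compiles the wrapper. The result is an ordinary ScriptModule that can be saved to its own *.pt file and called like the full model. `extract_consecutive` and `Toy` are hypothetical; for models with branches or extra ops between layers, this wrapper would not reproduce the original computation.

```python
import io

import torch
import torch.nn as nn


def extract_consecutive(scripted_model, names):
    """Hypothetical helper: chain the named submodules into one scripted module.

    Only valid if, in the original forward(), these submodules run back to back,
    each taking and returning a single Tensor, with no other operations
    (skip connections, reshapes, functional calls) in between.
    """
    modules = dict(scripted_model.named_modules())
    # The entries are already ScriptModules; torch.jit.script reuses them
    # and only compiles the nn.Sequential wrapper around them.
    return torch.jit.script(nn.Sequential(*(modules[n] for n in names)))


# Toy stand-in for a customer's model (an assumption, not their architecture).
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


# Round-trip through an in-memory buffer, standing in for torch.jit.load('sc_model.pt').
buf = io.BytesIO()
torch.jit.save(torch.jit.script(Toy()), buf)
buf.seek(0)
scripted_model = torch.jit.load(buf)

# Carve out fc1 -> act as one module and save it to a buffer
# (sub.save("sc_model_fc1_act.pt") would write it to disk instead).
sub = extract_consecutive(scripted_model, ["fc1", "act"])
out = io.BytesIO()
torch.jit.save(sub, out)
out.seek(0)
reloaded = torch.jit.load(out)

x = torch.randn(2, 4)
# The carved-out piece should reproduce the matching intermediate activation.
print(torch.allclose(reloaded(x), scripted_model.act(scripted_model.fc1(x))))
```

Comparing the reloaded subnetwork against the corresponding intermediate activation of the full model, as the last line does, is a cheap check that the extracted piece is valid before it is used for inference.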

I’d really appreciate some help here…
Thank you & Best regards
RB