[JIT] Concatenating two RecursiveScriptModules

Dear community,
I work on deploying PyTorch models. Our customers provide us with models that have been scripted with torch.jit.script(...) and saved as *.pt files. This is the common model exchange interface they have defined for us.

The “unholy” thing that I’m trying to achieve now is to concatenate two such scripted models, which are loaded as RecursiveScriptModules.

model1 = torch.jit.load('customer_model1.pt')
model2 = torch.jit.load('customer_model2.pt')

print(type(model1))  # Gives me <class 'torch.jit._script.RecursiveScriptModule'>
print(type(model2))  # Gives me <class 'torch.jit._script.RecursiveScriptModule'>

# ???

The question is whether I can somehow concatenate these two models. Every idea is highly appreciated…

Edit: maybe a side note: we use the models only for inference

Thank you in advance!
Best regards,
RB

Would you like to “concatenate” them such that the output of model1 would be passed to model2 in a sequential way? If so, then you might be able to create a new custom nn.Module, use both of the scripted models there, and script the new “parent model” again (if needed).
Let me know, if I misunderstood your question.

Dear @ptrblck,
thank you for your answer. Yes, that is exactly what I would like to achieve. However, I’m not sure how to convert even a single RecursiveScriptModule into an nn.Module without its original class definition available. I’d highly appreciate some hints, please.

Thank you & Best regards,
RB

I had something like this in mind:

import torch
import torch.nn as nn

# save: script two small stand-in models and write them to disk
model1 = nn.Linear(10, 5)
model2 = nn.Linear(5, 2)

model1 = torch.jit.script(model1)
model2 = torch.jit.script(model2)

torch.jit.save(model1, 'model1.pt')
torch.jit.save(model2, 'model2.pt')


# load: wrap the loaded scripted models in a plain parent module
class MyModel(nn.Module):
    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        
    def forward(self, x):
        x = self.model1(x)
        x = self.model2(x)
        return x

model1 = torch.jit.load('model1.pt')
model2 = torch.jit.load('model2.pt')
model = MyModel(model1, model2)

x = torch.randn(1, 10)
out = model(x)
print(out.shape)  # torch.Size([1, 2])

In other words, just load the scripted models and use them as submodules of a new parent model. The original class definitions are not needed for this.
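If the combined model has to go back through the same *.pt exchange format, the parent can be scripted and saved again, as mentioned earlier in the thread. The following is only a minimal sketch under a few assumptions: the class name `Chained`, the temporary file names, and the two `nn.Linear` stand-ins are hypothetical placeholders for the real customer models, and the output of the first model is assumed to be a single tensor that the second model accepts.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Chained(nn.Module):
    """Hypothetical wrapper: feeds the output of `first` into `second`."""

    def __init__(self, first: nn.Module, second: nn.Module):
        super().__init__()
        self.first = first
        self.second = second

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.second(self.first(x))


with tempfile.TemporaryDirectory() as tmp:
    # Stand-ins for the customer files; real models would already be on disk.
    path1 = os.path.join(tmp, "model1.pt")
    path2 = os.path.join(tmp, "model2.pt")
    torch.jit.save(torch.jit.script(nn.Linear(10, 5)), path1)
    torch.jit.save(torch.jit.script(nn.Linear(5, 2)), path2)

    model1 = torch.jit.load(path1)
    model2 = torch.jit.load(path2)

    # A loaded RecursiveScriptModule is already an nn.Module instance,
    # so it can be registered as a submodule without its class definition.
    is_module = isinstance(model1, nn.Module)

    # Script the parent, switch to eval mode, and save it as a single file.
    combined = torch.jit.script(Chained(model1, model2)).eval()
    combined_path = os.path.join(tmp, "combined.pt")
    torch.jit.save(combined, combined_path)
    reloaded = torch.jit.load(combined_path)

    # Inference only: no gradients are needed.
    x = torch.randn(1, 10)
    with torch.no_grad():
        expected = model2(model1(x))
        out = reloaded(x)

print(is_module, tuple(out.shape), torch.allclose(out, expected))
```

The reloaded file then behaves like any other scripted model: it can be loaded with torch.jit.load alone, without the `Chained` class being importable on the deployment side.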

Dear @ptrblck,
this is exactly what I wanted to achieve. I didn’t know it was possible this way.

Thank you very much,
RB