Determine forward() call chain of RecursiveScriptModule

Hello,
I’d like to determine the exact layer execution order (the forward() call chain) of a RecursiveScriptModule. I have to work with this particular type, because it is the predefined interface for model exchange in our flow.

Example: a dummy Fashion-MNIST model with one twist: I swapped the declaration order of layer1 and layer2 in __init__(), while forward() still calls them in the correct order (layer1 first).

import torch
import torch.nn as nn

class FashionCNN(nn.Module):

    def __init__(self):
        super(FashionCNN, self).__init__()

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
        self.drop = nn.Dropout2d(0.25)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

...
model = FashionCNN()
...
scripted_model = torch.jit.script(model.eval())
scripted_model.save('fmnist_scripted.pt')
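Scripting keeps the compiled body of forward(), so the call order is not lost when the model is saved. A quick manual check, sketched with a hypothetical two-layer model `Swapped` (not part of the original code) whose declaration order is reversed like layer1/layer2 above: `ScriptModule.code` returns the TorchScript source of forward() as a string.

```python
import torch
import torch.nn as nn


# Hypothetical toy model: layer2 is declared first, but forward() calls layer1 first.
class Swapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer2 = nn.Linear(8, 4)
        self.layer1 = nn.Linear(16, 8)

    def forward(self, x):
        out = self.layer1(x)
        return self.layer2(out)


scripted = torch.jit.script(Swapped().eval())

# .code is the TorchScript source of forward(); with one call per statement,
# as here, the attribute accesses appear in call order, not declaration order.
print(scripted.code)
```

This is fine for eyeballing a model, but parsing the printed text programmatically would be fragile.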

In another script, after loading the model back, I would like to determine the forward() call chain programmatically:

import torch
from collections import OrderedDict
...
scripted_model = torch.jit.load('fmnist_scripted.pt')

# Iterate over all modules in the hierarchy and keep only the leaves (no children)
layers = OrderedDict()
for name, module in scripted_model.named_modules():
    if not list(module.named_children()):
        layers[name] = module
# This lists the layers in declaration order, not in forward() execution order:
print(layers)
> OrderedDict([('layer2.0', RecursiveScriptModule(original_name=Conv2d)),
             ('layer2.1', RecursiveScriptModule(original_name=BatchNorm2d)),
             ('layer2.2', RecursiveScriptModule(original_name=ReLU)),
             ('layer2.3', RecursiveScriptModule(original_name=MaxPool2d)),
             ('layer1.0', RecursiveScriptModule(original_name=Conv2d)),
             ('layer1.1', RecursiveScriptModule(original_name=BatchNorm2d)),
             ('layer1.2', RecursiveScriptModule(original_name=ReLU)),
             ('layer1.3', RecursiveScriptModule(original_name=MaxPool2d)),
             ('fc1', RecursiveScriptModule(original_name=Linear)),
             ('drop', RecursiveScriptModule(original_name=Dropout2d)),
             ('fc2', RecursiveScriptModule(original_name=Linear)),
             ('fc3', RecursiveScriptModule(original_name=Linear))])
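The listing follows declaration order because named_modules() and named_children() iterate over the module registry, which is filled in the order the submodules are assigned in __init__(); forward() is never consulted. A minimal sketch with a hypothetical plain nn.Module (`Swapped`, no TorchScript involved) shows the same effect:

```python
import torch.nn as nn


# Hypothetical toy model: declared as layer2, layer1; forward() calls layer1 first.
class Swapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer2 = nn.Linear(8, 4)   # registered first
        self.layer1 = nn.Linear(16, 8)  # registered second

    def forward(self, x):
        return self.layer2(self.layer1(x))


# named_children() walks the registry populated by attribute assignment in
# __init__, so the order of calls in forward() plays no part here.
names = [name for name, _ in Swapped().named_children()]
print(names)  # ['layer2', 'layer1']
```

The scripted module keeps this registry order, as the RecursiveScriptModule listing above shows.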

Is there a way to obtain the actual order of the forward() call chain programmatically?
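One possible approach, sketched here under assumptions rather than as an official API: read the order from the non-inlined TorchScript graph instead of the module registry. In `module.graph`, each submodule call shows up as a `prim::CallMethod` node whose first input comes from `prim::GetAttr` nodes rooted at `self`, and nodes are listed in source order. The helper names `forward_call_chain` and `_resolve` and the toy model `Swapped` are hypothetical. The sketch relies on `torch._C` graph bindings (`Graph.nodes()`, `Node.kind()`, `Node.s()`, `Node.inputs()`, `Node.blocks()`, `Value.node()`), which are not a documented stable API and may need adjusting across PyTorch versions. It reports static call sites: a call inside an `if` branch appears whether or not that branch runs, and functional ops such as `view` are not modules, so they do not appear.

```python
import io

import torch
import torch.nn as nn


def _resolve(module, path):
    # Follow attribute names through named_children(); None if not a submodule.
    for name in path:
        module = dict(module.named_children()).get(name)
        if module is None:
            return None
    return module


def forward_call_chain(module, method="forward", prefix=""):
    """Sketch: qualified names of leaf submodules in the order their calls
    appear in the (non-inlined) TorchScript graph of module.<method>."""
    chain = []

    def attr_path(value):
        # Collect names along a chain of prim::GetAttr nodes back to a graph
        # input (normally `self`); None if the chain starts anywhere else.
        names = []
        while value.node().kind() == "prim::GetAttr":
            names.append(value.node().s("name"))
            value = list(value.node().inputs())[0]
        if value.node().kind() != "prim::Param":
            return None
        return names[::-1]

    def visit(block):
        for node in block.nodes():
            if node.kind() == "prim::CallMethod":
                path = attr_path(list(node.inputs())[0])
                callee = node.s("name")
                if path == []:
                    # Call to another method of the same module (e.g. a helper).
                    chain.extend(forward_call_chain(module, callee, prefix))
                elif path:
                    sub = _resolve(module, path)
                    qualname = prefix + ".".join(path)
                    if sub is not None and list(sub.named_children()):
                        # Container such as nn.Sequential: descend into its graph.
                        chain.extend(forward_call_chain(sub, callee, qualname + "."))
                    elif sub is not None:
                        chain.append(qualname)
            # Descend into the nested blocks of if/loop nodes.
            for inner in node.blocks():
                visit(inner)

    visit(getattr(module, method).graph)
    return chain


# Hypothetical toy model: declared as layer2, layer1, head; called as layer1, layer2, head.
class Swapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer2 = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
        self.layer1 = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        return self.head(out)


scripted = torch.jit.script(Swapped().eval())

# Round-trip through torch.jit.save/load, mirroring the save/load flow in the question.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

print(forward_call_chain(loaded))
# ['layer1.0', 'layer1.1', 'layer2.0', 'layer2.1', 'head']
```

Ordering by the `prim::CallMethod` nodes rather than the `prim::GetAttr` nodes matters: in a scripted nn.Sequential, all `getattr(self, "0")`, `getattr(self, "1")`, ... accesses are typically emitted before the calls themselves.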

@ptrblck, could you perhaps give me some hints? I’d really appreciate your help!

Thank you & Best regards,
RB