nn.ModuleList loses None objects inside it after scripting

🐛 Bug

To Reproduce

Executing torch.jit.script on my model works; however, it returns a model that fails at runtime.

Looking more deeply, the nn.ModuleList is losing its None elements after scripting.

Below, I attach code for reproducing the error:

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image


class TestBlock(nn.Module):
    def __init__(self):
        super(TestBlock, self).__init__()

        layers = []
        layers.append(None)  # placeholder entries that should be skipped in forward
        layers.append(None)
        layers.append(nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                                bias=False))
        self.layer = nn.ModuleList(layers)

    def forward(self, x):
        for aux in self.layer:
            print("ENTER")  # printed once per entry in the list
            if aux is not None:
                x = aux(x)
                print("Not None")
        return x

Creating the model and scripting it:

model = TestBlock()
traced_cell = torch.jit.script(model)

Testing model with an image:

img = Image.open("test.png")

my_transforms = transforms.Compose([transforms.Resize((1002, 1002)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(
                                        [0.485, 0.456, 0.406],
                                        [0.229, 0.224, 0.225])])
img_input = my_transforms(img).unsqueeze(0).cpu()

res = model(img_input)

This outputs the following:

ENTER
ENTER
ENTER
Not None

Scripted version output:

res = traced_cell(img_input)

ENTER
Not None

Expected behavior

Get the same output as the original model.

You are currently trying to script a tensor in:

traced_cell = torch.jit.script(aux)

so I assume you want to pass the model to the method instead?

Try to narrow down the issue, as your current code contains more than 700 lines.

Sorry, I pasted the wrong code. I am scripting the model with torch.jit.script(model).

I reduced the model to 15 lines that produce the same error. Please take another look, @ptrblck.

That’s great. Thanks for reducing the code.

I would assume it’s on purpose for the JIT to remove no-ops from the graph, as they won’t do anything.
What’s your use case that you need these None objects in an nn.ModuleList?
As a workaround you could probably just add nn.Identity() modules instead of Nones.
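
For instance, a minimal sketch of that workaround applied to the reduced TestBlock above (the random input shape is just for illustration):

import torch
import torch.nn as nn

class TestBlockIdentity(nn.Module):
    def __init__(self):
        super().__init__()
        layers = []
        layers.append(nn.Identity())  # stands in for the former None
        layers.append(nn.Identity())
        layers.append(nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                                bias=False))
        self.layer = nn.ModuleList(layers)

    def forward(self, x):
        for aux in self.layer:
            # nn.Identity() returns its input unchanged, so no None check is needed
            x = aux(x)
        return x

scripted = torch.jit.script(TestBlockIdentity())
out = scripted(torch.randn(1, 3, 8, 8))  # all three entries survive scripting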


This is not an architecture I defined; it is Microsoft's HRNet.

But if you look at their code, they are using it to create a new list:

y_list = self.stage2(x_list)

x_list = []
for i in range(len(self.transition2)):
    if self.transition2[i] is not None:
        x_list.append(self.transition2[i](y_list[-1]))
    else:
        x_list.append(y_list[i])

self.transition2 is the ModuleList containing the None objects.

With your workaround I should change the is not None check to not isinstance(layer, nn.Identity).
I am going to try your workaround and let you know!
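
A rough, self-contained sketch of that substitution (the module types, shapes, and names here are placeholders of mine, not HRNet's; I have not verified it against the full model, and isinstance support in TorchScript may depend on the PyTorch version):

import torch
import torch.nn as nn
from typing import List

class TransitionSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.transition2 = nn.ModuleList([
            nn.Identity(),                   # was None in the original
            nn.Conv2d(4, 4, kernel_size=1),  # placeholder branch module
        ])

    def forward(self, y_list: List[torch.Tensor]) -> List[torch.Tensor]:
        x_list = []
        # TorchScript unrolls loops over an nn.ModuleList, so each `trans`
        # has a concrete type and the isinstance check can be resolved
        for i, trans in enumerate(self.transition2):
            if isinstance(trans, nn.Identity):
                x_list.append(y_list[i])
            else:
                x_list.append(trans(y_list[-1]))
        return x_list

m = torch.jit.script(TransitionSketch())
outs = m([torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)])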

The previous approach works.

However, the model works fast on the first image; if we pass an image again, it never ends! In addition, the TorchScript model is also slower on the first image.

Thanks for the update and the code snippet.
The nn.Identity() approach might work, but looks quite hacky given the new code snippet.

However, I’m not familiar with the model, so I don’t know which approach would be best to make it scriptable, and would suggest creating an issue in their GitHub repository.

Do I understand correctly that the eager model works for a single iteration and hangs in the second one, while the scripted model is slower in the first iteration and works fine afterwards?

The model works well in eager mode. After scripting it, the first iteration takes about 6 seconds, the second one around 3 minutes, and from the third one on about 1 second.

I think it is being optimized. Is there any way of pre-optimizing it in a Flask REST API, or of disabling the optimization? Or of saving the optimized one, so that when it gets loaded it is already the optimized version?

You could try to use

# internal switches that disable the profiling executor and profiling mode
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)

at the beginning of your script to disable the optimization.
However, if your Flask application runs for longer than a couple of seconds, the startup time could probably be ignored.

I don’t understand it.

My Flask application has several models loaded. It runs inference on the same image with different models; the model names are given in a list.

Since the second iteration seems to run for 3 minutes, I would ask you to create an issue here, as it doesn’t seem right.

If the first iteration would be a warmup time (ignoring the 3 minutes, which seems to be a bug), then you would only pay the cost once. Every other time the prediction would use the optimized graph and should be fast.
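
If a warmup at startup works for you, a hedged sketch of what it could look like (the model path, input shape, and iteration count are assumptions of mine):

import torch

scripted_model = torch.jit.load("model.pt")  # hypothetical saved ScriptModule
dummy = torch.randn(1, 3, 1002, 1002)        # shape assumed from the transforms above
with torch.no_grad():
    for _ in range(3):  # a few passes so the JIT finishes profiling and optimizing
        scripted_model(dummy)
# ... only start serving requests after the warmup has finished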

Ah, okay. Thank you very much for the info and all your help on the forums!

What should I submit in the issue? The traced model or the model definition code?

I have an issue open for None objects disappearing from inside nn.ModuleList.

If possible, a minimal code snippet, which is executable and shows the JIT behavior, where the second iteration takes 3 minutes, while the first one finishes in 6 seconds.

I don’t know if I will be able to paste a minimal snippet. This post was created from just the part of the model where the nn.ModuleList contains Nones.

I have tried to measure times again:

Without scripting:

CPU times: user 12.8 s, sys: 1.41 s, total: 14.2 s
Wall time: 2.32 s

Scripted First Iteration:

CPU times: user 15 s, sys: 1.77 s, total: 16.8 s
Wall time: 4.64 s

Scripted Second Iteration:

CPU times: user 5min 8s, sys: 1.14 s, total: 5min 9s
Wall time: 5min

Scripted Third Iteration:

CPU times: user 11 s, sys: 1.18 s, total: 12.2 s
Wall time: 2.02 s
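
These figures look like IPython's %time output; for reference, a minimal standalone loop along the same lines (the input shape is assumed from the transforms above):

import time
import torch

scripted = torch.jit.script(model)  # `model` as defined earlier in the thread
x = torch.randn(1, 3, 1002, 1002)
with torch.no_grad():
    for i in range(3):
        start = time.time()
        scripted(x)
        print(f"iteration {i}: {time.time() - start:.2f} s wall time")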

Done.